import time
from datetime import datetime
from functools import partial

# NOTE: `imagenet_run`, `log_results`, and `save_checkpoint` are assumed to be
# defined elsewhere in this module.


def log_background_task(_, config, queue):
    """Background task that updates logger files using ray.tune's
    `UnifiedLogger`."""
    # Configure ray.tune loggers
    from ray.tune.logger import UnifiedLogger
    logger = UnifiedLogger(config=config,
                           logdir=config["logdir"],
                           loggers=config.get("loggers", None))
    logger.last_timestamp = logger.start_timestamp = time.time()

    # Wait for results
    results = queue.get()
    while results is not None:
        # Update ray.tune fields
        timestamp = results["timestamp"]
        results["config"] = config
        results["experiment_id"] = config["experiment_id"]
        results["experiment_tag"] = config["experiment_tag"]
        results["training_iteration"] = results.pop("epoch")
        # Negate the loss, matching Tune's auto-filled "neg_mean_loss" metric.
        results["neg_mean_loss"] = -results["mean_loss"]
        results["timesteps_total"] = results.get("timestep", 0)
        results["timestamp"] = int(timestamp)
        results["date"] = datetime.fromtimestamp(timestamp).strftime(
            "%Y-%m-%d_%H-%M-%S")
        results["time_this_iter_s"] = timestamp - logger.last_timestamp
        results["time_total_s"] = timestamp - logger.start_timestamp
        logger.on_result(results)
        logger.last_timestamp = timestamp

        # Wait for next results
        results = queue.get()

    logger.flush()
    logger.close()
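# The helper below is an illustrative sketch, not part of the original module.
# It shows the queue protocol `log_background_task` expects: one result dict
# per epoch (carrying "timestamp", "epoch", "mean_loss", and optionally
# "timestep"), followed by a `None` sentinel that flushes and closes the
# logger. The function name and the dummy values are hypothetical.
def _example_drive_background_logger(config, num_epochs=3):
    import multiprocessing

    queue = multiprocessing.Queue()
    worker = multiprocessing.Process(
        target=log_background_task, args=(None, config, queue))
    worker.start()
    for epoch in range(num_epochs):
        queue.put({"timestamp": time.time(), "epoch": epoch,
                   "mean_loss": 0.5, "timestep": 100 * (epoch + 1)})
    queue.put(None)  # Sentinel: terminates the logging loop.
    worker.join()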
def run_trial(config):
    """Run a single trial configuration."""
    # Configure ray.tune loggers
    from ray.tune.logger import UnifiedLogger
    logger = UnifiedLogger(config=config,
                           logdir=config["logdir"],
                           loggers=config.get("loggers", None))
    logger.last_timestamp = logger.start_timestamp = time.time()

    result = imagenet_run.run(config=config,
                              logger=partial(log_results, logger, config),
                              on_checkpoint=partial(save_checkpoint, config))

    logger.flush()
    logger.close()
    return result
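# `log_results` is referenced above but defined elsewhere in the module. The
# sketch below shows one plausible shape for it, assuming `imagenet_run.run`
# invokes the callback as `logger_fn(results)` once per epoch; it mirrors the
# field fix-ups done in `log_background_task`. The name, signature, and body
# are assumptions, not the original implementation.
def _example_log_results(logger, config, results):
    timestamp = results.get("timestamp", time.time())
    results["config"] = config
    results["training_iteration"] = results.pop("epoch", 0)
    results["timestamp"] = int(timestamp)
    logger.on_result(results)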
# NOTE: this fragment (and the Trainable/Trial fragments below) assumes the
# usual module-level imports (copy, logging, os, pickle, platform, shutil,
# sys, tempfile, time, uuid, ray, redirect_stdout/redirect_stderr, datetime)
# and the ray.tune constants and helpers referenced in the bodies (e.g.
# TRIAL_INFO, STDOUT_FILE, STDERR_FILE, SETUP_TIME_THRESHOLD, UtilMonitor,
# TrainableUtil, UnifiedLogger, Tee, log_once, logger) are in scope.


class Trainable:
    """Abstract class for trainable models, functions, etc.

    A call to ``train()`` on a trainable will execute one logical iteration
    of training. As a rule of thumb, the execution time of one train call
    should be large enough to avoid overheads (i.e. more than a few seconds),
    but short enough to report progress periodically (i.e. at most a few
    minutes).

    Calling ``save()`` should save the training state of a trainable to
    disk, and ``restore(path)`` should restore a trainable to the given
    state.

    Generally you only need to implement ``setup``, ``step``,
    ``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable.

    Other implementation methods that may be helpful to override are
    ``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``.

    When using Tune, Tune will convert this class into a Ray actor, which
    runs on a separate process. Tune will also change the current working
    directory of this process to ``self.logdir``.
    """

    def __init__(self, config=None, logger_creator=None):
        """Initialize a Trainable.

        Sets up logging and points ``self.logdir`` to a directory in which
        training outputs should be placed.

        Subclasses should prefer defining ``setup()`` instead of overriding
        ``__init__()`` directly.

        Args:
            config (dict): Trainable-specific configuration data. By default
                will be saved as ``self.config``.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """
        self._experiment_id = uuid.uuid4().hex
        self.config = config or {}
        trial_info = self.config.pop(TRIAL_INFO, None)

        self._result_logger = self._logdir = None
        self._create_logger(self.config, logger_creator)

        self._stdout_context = self._stdout_fp = self._stdout_stream = None
        self._stderr_context = self._stderr_fp = self._stderr_stream = None
        self._stderr_logging_handler = None

        stdout_file = self.config.pop(STDOUT_FILE, None)
        stderr_file = self.config.pop(STDERR_FILE, None)
        self._open_logfiles(stdout_file, stderr_file)

        self._iteration = 0
        self._time_total = 0.0
        self._timesteps_total = None
        self._episodes_total = None
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = False
        self._trial_info = trial_info

        start_time = time.time()
        self.setup(copy.deepcopy(self.config))
        setup_time = time.time() - start_time
        if setup_time > SETUP_TIME_THRESHOLD:
            logger.info("Trainable.setup took {:.3f} seconds. If your "
                        "trainable is slow to initialize, consider setting "
                        "reuse_actors=True to reduce actor creation "
                        "overheads.".format(setup_time))
        self._local_ip = self.get_current_ip()
        log_sys_usage = self.config.get("log_sys_usage", False)
        self._monitor = UtilMonitor(start=log_sys_usage)

    @classmethod
    def default_resource_request(cls, config):
        """Provides a static resource requirement for the given configuration.

        This can be overridden by sub-classes to set the correct trial
        resource allocation, so the user does not need to.

        .. code-block:: python

            @classmethod
            def default_resource_request(cls, config):
                return Resources(
                    cpu=0,
                    gpu=0,
                    extra_cpu=config["workers"],
                    extra_gpu=int(config["use_gpu"]) * config["workers"])

        Returns:
            Resources: A Resources object consumed by Tune for queueing.
        """
        return None

    @classmethod
    def resource_help(cls, config):
        """Returns a help string for configuring this trainable's resources.

        Args:
            config (dict): The Trainer's config dict.
""" return "" def get_current_ip(self): self._local_ip = ray.services.get_node_ip_address() return self._local_ip def train(self): """Runs one logical iteration of training. Calls ``step()`` internally. Subclasses should override ``step()`` instead to return results. This method automatically fills the following fields in the result: `done` (bool): training is terminated. Filled only if not provided. `time_this_iter_s` (float): Time in seconds this iteration took to run. This may be overriden in order to override the system-computed time difference. `time_total_s` (float): Accumulated time in seconds for this entire experiment. `experiment_id` (str): Unique string identifier for this experiment. This id is preserved across checkpoint / restore calls. `training_iteration` (int): The index of this training iteration, e.g. call to train(). This is incremented after `step()` is called. `pid` (str): The pid of the training process. `date` (str): A formatted date of when the result was processed. `timestamp` (str): A UNIX timestamp of when the result was processed. `hostname` (str): Hostname of the machine hosting the training process. `node_ip` (str): Node ip of the machine hosting the training process. Returns: A dict that describes training progress. """ start = time.time() result = self.step() assert isinstance(result, dict), "step() needs to return a dict." # We do not modify internal state nor update this result if duplicate. if RESULT_DUPLICATE in result: return result result = result.copy() self._iteration += 1 self._iterations_since_restore += 1 if result.get(TIME_THIS_ITER_S) is not None: time_this_iter = result[TIME_THIS_ITER_S] else: time_this_iter = time.time() - start self._time_total += time_this_iter self._time_since_restore += time_this_iter result.setdefault(DONE, False) # self._timesteps_total should only be tracked if increments provided if result.get(TIMESTEPS_THIS_ITER) is not None: if self._timesteps_total is None: self._timesteps_total = 0 self._timesteps_total += result[TIMESTEPS_THIS_ITER] self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] # self._episodes_total should only be tracked if increments provided if result.get(EPISODES_THIS_ITER) is not None: if self._episodes_total is None: self._episodes_total = 0 self._episodes_total += result[EPISODES_THIS_ITER] # self._timesteps_total should not override user-provided total result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total) result.setdefault(EPISODES_TOTAL, self._episodes_total) result.setdefault(TRAINING_ITERATION, self._iteration) # Provides auto-filled neg_mean_loss for avoiding regressions if result.get("mean_loss"): result.setdefault("neg_mean_loss", -result["mean_loss"]) now = datetime.today() result.update( experiment_id=self._experiment_id, date=now.strftime("%Y-%m-%d_%H-%M-%S"), timestamp=int(time.mktime(now.timetuple())), time_this_iter_s=time_this_iter, time_total_s=self._time_total, pid=os.getpid(), hostname=platform.node(), node_ip=self._local_ip, config=self.config, time_since_restore=self._time_since_restore, timesteps_since_restore=self._timesteps_since_restore, iterations_since_restore=self._iterations_since_restore) monitor_data = self._monitor.get_data() if monitor_data: result.update(monitor_data) self.log_result(result) if self._stdout_context: self._stdout_stream.flush() if self._stderr_context: self._stderr_stream.flush() return result def get_state(self): return { "experiment_id": self._experiment_id, "iteration": self._iteration, "timesteps_total": self._timesteps_total, 
"time_total": self._time_total, "episodes_total": self._episodes_total, "ray_version": ray.__version__, } def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. Subclasses should override ``_save()`` instead to save state. This method dumps additional metadata alongside the saved path. Args: checkpoint_dir (str): Optional dir to place the checkpoint. Returns: str: Checkpoint path or prefix that may be passed to restore(). """ checkpoint_dir = TrainableUtil.make_checkpoint_dir( checkpoint_dir or self.logdir, index=self.iteration) checkpoint = self.save_checkpoint(checkpoint_dir) trainable_state = self.get_state() checkpoint_path = TrainableUtil.process_checkpoint( checkpoint, parent_dir=checkpoint_dir, trainable_state=trainable_state) return checkpoint_path def save_to_object(self): """Saves the current model state to a Python object. It also saves to disk but does not return the checkpoint path. Returns: Object holding checkpoint data. """ tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) checkpoint_path = self.save(tmpdir) # Save all files in subtree and delete the tmpdir. obj = TrainableUtil.checkpoint_to_object(checkpoint_path) shutil.rmtree(tmpdir) return obj def restore(self, checkpoint_path): """Restores training state from a given model checkpoint. These checkpoints are returned from calls to save(). Subclasses should override ``_restore()`` instead to restore state. This method restores additional metadata saved with the checkpoint. """ with open(checkpoint_path + ".tune_metadata", "rb") as f: metadata = pickle.load(f) self._experiment_id = metadata["experiment_id"] self._iteration = metadata["iteration"] self._timesteps_total = metadata["timesteps_total"] self._time_total = metadata["time_total"] self._episodes_total = metadata["episodes_total"] saved_as_dict = metadata["saved_as_dict"] if saved_as_dict: with open(checkpoint_path, "rb") as loaded_state: checkpoint_dict = pickle.load(loaded_state) checkpoint_dict.update(tune_checkpoint_path=checkpoint_path) self.load_checkpoint(checkpoint_dict) else: self.load_checkpoint(checkpoint_path) self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = True logger.info("Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path) state = { "_iteration": self._iteration, "_timesteps_total": self._timesteps_total, "_time_total": self._time_total, "_episodes_total": self._episodes_total, } logger.info("Current state after restoring: %s", state) def restore_from_object(self, obj): """Restores training state from a checkpoint object. These checkpoints are returned from calls to save_to_object(). """ tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir) checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir) self.restore(checkpoint_path) shutil.rmtree(tmpdir) def delete_checkpoint(self, checkpoint_path): """Deletes local copy of checkpoint. Args: checkpoint_path (str): Path to checkpoint. """ try: checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) except FileNotFoundError: # The checkpoint won't exist locally if the # trial was rescheduled to another worker. logger.debug("Checkpoint not found during garbage collection.") return if os.path.exists(checkpoint_dir): shutil.rmtree(checkpoint_dir) def export_model(self, export_formats, export_dir=None): """Exports model based on export_formats. Subclasses should override _export_model() to actually export model to local directory. 
        Args:
            export_formats (Union[list,str]): Format or list of (str) formats
                that should be exported.
            export_dir (str): Optional dir to place the exported model.
                Defaults to self.logdir.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        if isinstance(export_formats, str):
            export_formats = [export_formats]
        export_dir = export_dir or self.logdir
        return self._export_model(export_formats, export_dir)

    def reset(self, new_config, logger_creator=None):
        """Resets trial for use with new config.

        Subclasses should override reset_config() to actually
        reset actor behavior for the new config."""
        self.config = new_config

        self._result_logger.flush()
        self._result_logger.close()
        self._create_logger(new_config.copy(), logger_creator)

        stdout_file = new_config.pop(STDOUT_FILE, None)
        stderr_file = new_config.pop(STDERR_FILE, None)

        self._close_logfiles()
        self._open_logfiles(stdout_file, stderr_file)

        return self.reset_config(new_config)

    def reset_config(self, new_config):
        """Resets configuration without restarting the trial.

        This method is optional, but can be implemented to speed up algorithms
        such as PBT, and to allow performance optimizations such as running
        experiments with reuse_actors=True.

        Args:
            new_config (dict): Updated hyperparameter configuration
                for the trainable.

        Returns:
            True if reset was successful else False.
        """
        return False

    def _create_logger(self, config, logger_creator=None):
        """Create logger from logger creator.

        Sets _logdir and _result_logger.
        """
        if logger_creator:
            self._result_logger = logger_creator(config)
            self._logdir = self._result_logger.logdir
        else:
            logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            ray.utils.try_to_create_directory(DEFAULT_RESULTS_DIR)
            self._logdir = tempfile.mkdtemp(
                prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
            self._result_logger = UnifiedLogger(
                config, self._logdir, loggers=None)

    def _open_logfiles(self, stdout_file, stderr_file):
        """Create loggers. Open stdout and stderr logfiles."""
        if stdout_file:
            stdout_path = os.path.expanduser(
                os.path.join(self._logdir, stdout_file))
            self._stdout_fp = open(stdout_path, "a+")
            self._stdout_stream = Tee(sys.stdout, self._stdout_fp)
            self._stdout_context = redirect_stdout(self._stdout_stream)
            self._stdout_context.__enter__()

        if stderr_file:
            stderr_path = os.path.expanduser(
                os.path.join(self._logdir, stderr_file))
            self._stderr_fp = open(stderr_path, "a+")
            self._stderr_stream = Tee(sys.stderr, self._stderr_fp)
            self._stderr_context = redirect_stderr(self._stderr_stream)
            self._stderr_context.__enter__()

            # Add logging handler to root ray logger
            formatter = logging.Formatter("[%(levelname)s %(asctime)s] "
                                          "%(filename)s: %(lineno)d "
                                          "%(message)s")
            self._stderr_logging_handler = logging.StreamHandler(
                self._stderr_fp)
            self._stderr_logging_handler.setFormatter(formatter)
            ray.logger.addHandler(self._stderr_logging_handler)

    def _close_logfiles(self):
        """Close stdout and stderr logfiles."""
        if self._stderr_logging_handler:
            ray.logger.removeHandler(self._stderr_logging_handler)

        if self._stdout_context:
            self._stdout_stream.flush()
            self._stdout_context.__exit__(None, None, None)
            self._stdout_fp.close()
            self._stdout_context = None
        if self._stderr_context:
            self._stderr_stream.flush()
            self._stderr_context.__exit__(None, None, None)
            self._stderr_fp.close()
            self._stderr_context = None

    def stop(self):
        """Releases all resources used by this trainable.

        Calls ``Trainable.cleanup`` internally. Subclasses should override
        ``Trainable.cleanup`` for custom cleanup procedures.
""" self._result_logger.flush() self._result_logger.close() self.cleanup() self._close_logfiles() @property def logdir(self): """Directory of the results and checkpoints for this Trainable. Tune will automatically sync this folder with the driver if execution is distributed. Note that the current working directory will also be changed to this. """ return os.path.join(self._logdir, "") @property def trial_name(self): """Trial name for the corresponding trial of this Trainable. This is not set if not using Tune. .. code-block:: python name = self.trial_name """ if self._trial_info: return self._trial_info.trial_name else: return "default" @property def trial_id(self): """Trial ID for the corresponding trial of this Trainable. This is not set if not using Tune. .. code-block:: python trial_id = self.trial_id """ if self._trial_info: return self._trial_info.trial_id else: return "default" @property def iteration(self): """Current training iteration. This value is automatically incremented every time `train()` is called and is automatically inserted into the training result dict. """ return self._iteration @property def training_iteration(self): """Current training iteration (same as `self.iteration`). This value is automatically incremented every time `train()` is called and is automatically inserted into the training result dict. """ return self._iteration def get_config(self): """Returns configuration passed in by Tune.""" return self.config def step(self): """Subclasses should override this to implement train(). The return value will be automatically passed to the loggers. Users can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables. .. versionadded:: 0.8.7 Returns: A dict that describes training progress. """ result = self._train() if self._is_overriden("_train") and log_once("_train"): logger.warning( "Trainable._train is deprecated and will be removed in " "a future version of Ray. Override Trainable.step instead.") return result def _train(self): """This method is deprecated. Override 'Trainable.step' instead. .. versionchanged:: 0.8.7 """ raise NotImplementedError def save_checkpoint(self, tmp_checkpoint_dir): """Subclasses should override this to implement ``save()``. Warning: Do not rely on absolute paths in the implementation of ``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``. Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/ ``Trainable.load_checkpoint`` errors before execution. >>> from ray.tune.utils import validate_save_restore >>> validate_save_restore(MyTrainableClass) >>> validate_save_restore(MyTrainableClass, use_object_store=True) .. versionadded:: 0.8.7 Args: tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved. Returns: A dict or string. If string, the return value is expected to be prefixed by `tmp_checkpoint_dir`. If dict, the return value will be automatically serialized by Tune and passed to ``Trainable.load_checkpoint()``. Examples: >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) "/tmp/checkpoint_1/my_checkpoint_file" >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) {"some": "data"} >>> trainable.save_checkpoint("/tmp/bad_example") "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error. 
""" checkpoint = self._save(tmp_checkpoint_dir) if self._is_overriden("_save") and log_once("_save"): logger.warning( "Trainable._save is deprecated and will be removed in a " "future version of Ray. Override " "Trainable.save_checkpoint instead.") return checkpoint def _save(self, tmp_checkpoint_dir): """This method is deprecated. Override 'save_checkpoint' instead. .. versionchanged:: 0.8.7 """ raise NotImplementedError def load_checkpoint(self, checkpoint): """Subclasses should override this to implement restore(). Warning: In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in ``Trainable.save_checkpoint`` may be changed. If ``Trainable.save_checkpoint`` returned a prefixed string, the prefix of the checkpoint string returned by ``Trainable.save_checkpoint`` may be changed. This is because trial pausing depends on temporary directories. The directory structure under the checkpoint_dir provided to ``Trainable.save_checkpoint`` is preserved. See the example below. .. code-block:: python class Example(Trainable): def save_checkpoint(self, checkpoint_path): print(checkpoint_path) return os.path.join(checkpoint_path, "my/check/point") def load_checkpoint(self, checkpoint): print(checkpoint) >>> trainer = Example() >>> obj = trainer.save_to_object() # This is used when PAUSED. <logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point >>> trainer.restore_from_object(obj) # Note the different prefix. <logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point .. versionadded:: 0.8.7 Args: checkpoint (str|dict): If dict, the return value is as returned by `save_checkpoint`. If a string, then it is a checkpoint path that may have a different prefix than that returned by `save_checkpoint`. The directory structure underneath the `checkpoint_dir` `save_checkpoint` is preserved. """ self._restore(checkpoint) if self._is_overriden("_restore") and log_once("_restore"): logger.warning( "Trainable._restore is deprecated and will be removed in a " "future version of Ray. Override Trainable.load_checkpoint " "instead.") def _restore(self, checkpoint): """This method is deprecated. Override 'load_checkpoint' instead. .. versionchanged:: 0.8.7 """ raise NotImplementedError def setup(self, config): """Subclasses should override this for custom initialization. .. versionadded:: 0.8.7 Args: config (dict): Hyperparameters and other configs given. Copy of `self.config`. """ self._setup(config) if self._is_overriden("_setup") and log_once("_setup"): logger.warning( "Trainable._setup is deprecated and will be removed in " "a future version of Ray. Override Trainable.setup instead.") def _setup(self, config): """This method is deprecated. Override 'setup' instead. .. versionchanged:: 0.8.7 """ pass def log_result(self, result): """Subclasses can optionally override this to customize logging. The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the ``loggers`` parameter in ``tune.run`` when overriding this function. .. versionadded:: 0.8.7 Args: result (dict): Training result returned by step(). """ self._log_result(result) if self._is_overriden("_log_result") and log_once("_log_result"): logger.warning( "Trainable._log_result is deprecated and will be removed in " "a future version of Ray. Override " "Trainable.log_result instead.") def _log_result(self, result): """This method is deprecated. Override 'log_result' instead. .. 
        .. versionchanged:: 0.8.7
        """
        self._result_logger.on_result(result)

    def cleanup(self):
        """Subclasses should override this for any cleanup on stop.

        If any Ray actors are launched in the Trainable (i.e., with a RLlib
        trainer), be sure to kill the Ray actor process here. You can kill a
        Ray actor by calling `actor.__ray_terminate__.remote()` on the actor.

        .. versionadded:: 0.8.7
        """
        self._stop()
        if self._is_overriden("_stop") and log_once("trainable.cleanup"):
            logger.warning(
                "Trainable._stop is deprecated and will be removed in "
                "a future version of Ray. Override Trainable.cleanup instead.")

    def _stop(self):
        """This method is deprecated. Override 'cleanup' instead.

        .. versionchanged:: 0.8.7
        """
        pass

    def _export_model(self, export_formats, export_dir):
        """Subclasses should override this to export model.

        Args:
            export_formats (list): List of formats that should be exported.
            export_dir (str): Directory to place exported models.

        Return:
            A dict that maps ExportFormats to successfully exported models.
        """
        return {}

    def _is_overriden(self, key):
        return getattr(self, key).__code__ != getattr(Trainable, key).__code__
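# Minimal subclass sketch of the contract described in the docstrings above:
# implement `setup`, `step`, `save_checkpoint`, and `load_checkpoint`, and let
# `train()`/`save()`/`restore()` handle the bookkeeping. The class name and
# the "lr" hyperparameter are illustrative, not part of the original module.
import os
import pickle


class MyTrainableClass(Trainable):
    def setup(self, config):
        self.lr = config.get("lr", 0.01)
        self.score = 0.0

    def step(self):
        # One logical training iteration; `train()` auto-fills fields such as
        # `training_iteration` and `time_this_iter_s` around this result.
        self.score += self.lr
        return {"mean_loss": 1.0 / (1.0 + self.score)}

    def save_checkpoint(self, tmp_checkpoint_dir):
        path = os.path.join(tmp_checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump({"score": self.score}, f)
        return path  # Prefixed by tmp_checkpoint_dir, as required.

    def load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, "rb") as f:
            self.score = pickle.load(f)["score"]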
class Trial(object):
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error it transitions to ERROR, otherwise TERMINATED on success.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 export_formats=None,
                 restore_path=None,
                 upload_dir=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_function=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        Trial._registration_check(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.config = config or {}
        self.local_dir = os.path.expanduser(local_dir)
        self.experiment_tag = experiment_tag
        self.resources = (
            resources or
            self._get_trainable_cls().default_resource_request(self.config))
        self.stopping_criterion = stopping_criterion or {}
        self.upload_dir = upload_dir
        self.loggers = loggers
        self.sync_function = sync_function
        validate_sync_function(sync_function)
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self._checkpoint = Checkpoint(
            storage=Checkpoint.DISK, value=restore_path)
        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.error_file = None
        self.num_failures = 0
        self.custom_trial_name = None

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "_checkpoint",
            "config",
            "loggers",
            "sync_function",
            "last_result",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @classmethod
    def _registration_check(cls, trainable_name):
        if not has_trainable(trainable_name):
            # Make sure rllib agents are registered
            from ray import rllib  # noqa: F401
            if not has_trainable(trainable_name):
                raise TuneError("Unknown trainable: " + trainable_name)

    @classmethod
    def generate_id(cls):
        return binary_to_hex(_random_string())[:8]

    def init_logger(self):
        """Init logger."""
        if not self.result_logger:
            if not os.path.exists(self.local_dir):
                os.makedirs(self.local_dir)
            if not self.logdir:
                self.logdir = tempfile.mkdtemp(
                    prefix="{}_{}".format(
                        str(self)[:MAX_LEN_IDENTIFIER], date_str()),
                    dir=self.local_dir)
            elif not os.path.exists(self.logdir):
                os.makedirs(self.logdir)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                upload_uri=self.upload_dir,
                loggers=self.loggers,
                sync_function=self.sync_function)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError if trial status is running.
""" if self.status is Trial.RUNNING: raise ValueError("Cannot update resources while Trial is running.") self.resources = Resources(cpu, gpu, **kwargs) def sync_logger_to_new_location(self, worker_ip): """Updates the logger location. Also pushes logdir to worker_ip, allowing for cross-node recovery. """ if self.result_logger: self.result_logger.sync_results_to_new_location(worker_ip) def close_logger(self): """Close logger.""" if self.result_logger: self.result_logger.close() self.result_logger = None def write_error_log(self, error_msg): if error_msg and self.logdir: self.num_failures += 1 # may be moved to outer scope? error_file = os.path.join(self.logdir, "error_{}.txt".format(date_str())) with open(error_file, "w") as f: f.write(error_msg) self.error_file = error_file def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" if result.get(DONE): return True for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( "Stopping criteria {} not provided in result {}.".format( criteria, result)) if result[criteria] >= stop_value: return True return False def should_checkpoint(self): """Whether this trial is due for checkpointing.""" result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True if self.checkpoint_freq: return result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0 else: return False def progress_string(self): """Returns a progress message for printing out to the console.""" if not self.last_result: return self._status_string() def location_string(hostname, pid): if hostname == os.uname()[1]: return 'pid={}'.format(pid) else: return '{} pid={}'.format(hostname, pid) pieces = [ '{}'.format(self._status_string()), '[{}]'.format( self.resources.summary_string()), '[{}]'.format( location_string( self.last_result.get(HOSTNAME), self.last_result.get(PID))), '{} s'.format( int(self.last_result.get(TIME_TOTAL_S))) ] if self.last_result.get(TRAINING_ITERATION) is not None: pieces.append('{} iter'.format( self.last_result[TRAINING_ITERATION])) if self.last_result.get(TIMESTEPS_TOTAL) is not None: pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL])) if self.last_result.get(EPISODE_REWARD_MEAN) is not None: pieces.append('{} rew'.format( format(self.last_result[EPISODE_REWARD_MEAN], '.3g'))) if self.last_result.get(MEAN_LOSS) is not None: pieces.append('{} loss'.format( format(self.last_result[MEAN_LOSS], '.3g'))) if self.last_result.get(MEAN_ACCURACY) is not None: pieces.append('{} acc'.format( format(self.last_result[MEAN_ACCURACY], '.3g'))) return ', '.join(pieces) def _status_string(self): return "{}{}".format( self.status, ", {} failures: {}".format(self.num_failures, self.error_file) if self.error_file else "") def has_checkpoint(self): return self._checkpoint.value is not None def clear_checkpoint(self): self._checkpoint.value = None def should_recover(self): """Returns whether the trial qualifies for restoring. This is if a checkpoint frequency is set and has not failed more than max_failures. This may return true even when there may not yet be a checkpoint. 
""" return (self.checkpoint_freq > 0 and (self.num_failures < self.max_failures or self.max_failures < 0)) def update_last_result(self, result, terminate=False): if terminate: result.update(done=True) if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL): print("Result for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() self.last_result = result self.last_update_time = time.time() self.result_logger.on_result(self.last_result) def _get_trainable_cls(self): return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) def set_verbose(self, verbose): self.verbose = verbose def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] def __repr__(self): return str(self) def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``. Can be overriden with a custom string creator. """ if self.custom_trial_name: return self.custom_trial_name if "env" in self.config: env = self.config["env"] if isinstance(env, type): env = env.__name__ identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: identifier += "_" + self.experiment_tag return identifier.replace("/", "_") def __getstate__(self): """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a DISK checkpoint. """ assert self._checkpoint.storage == Checkpoint.DISK, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) for key in self._nonjson_fields: state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None state["result_logger"] = None if self.result_logger: self.result_logger.flush() state["__logger_started__"] = True else: state["__logger_started__"] = False return copy.deepcopy(state) def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) if state["status"] == Trial.RUNNING: state["status"] = Trial.PENDING for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) Trial._registration_check(self.trainable_name) if logger_started: self.init_logger()
class Trial:
    """A trial object holds the state for one model training run.

    Trials are themselves managed by the TrialRunner class, which implements
    the event loop for submitting trial runs to a Ray cluster.

    Trials start in the PENDING state, and transition to RUNNING once started.
    On error it transitions to ERROR, otherwise TERMINATED on success.
    """

    PENDING = "PENDING"
    RUNNING = "RUNNING"
    PAUSED = "PAUSED"
    TERMINATED = "TERMINATED"
    ERROR = "ERROR"

    def __init__(self,
                 trainable_name,
                 config=None,
                 trial_id=None,
                 local_dir=DEFAULT_RESULTS_DIR,
                 evaluated_params=None,
                 experiment_tag="",
                 resources=None,
                 stopping_criterion=None,
                 remote_checkpoint_dir=None,
                 checkpoint_freq=0,
                 checkpoint_at_end=False,
                 sync_on_checkpoint=True,
                 keep_checkpoints_num=None,
                 checkpoint_score_attr=TRAINING_ITERATION,
                 export_formats=None,
                 restore_path=None,
                 trial_name_creator=None,
                 loggers=None,
                 sync_to_driver_fn=None,
                 max_failures=0):
        """Initialize a new trial.

        The args here take the same meaning as the command line flags defined
        in ray.tune.config_parser.
        """
        validate_trainable(trainable_name)
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.local_dir = local_dir  # This remains unexpanded for syncing.

        #: Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        trainable_cls = self.get_trainable_cls()
        if trainable_cls and hasattr(trainable_cls,
                                     "default_resource_request"):
            default_resources = trainable_cls.default_resource_request(
                self.config)
            if default_resources:
                if resources:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources
        self.location = Location()
        self.resources = resources or Resources(cpu=1, gpu=0)
        self.stopping_criterion = stopping_criterion or {}
        self.loggers = loggers
        self.sync_to_driver_fn = sync_to_driver_fn
        self.verbose = True
        self.max_failures = max_failures

        # Local trial state that is updated during the run
        self.last_result = {}
        self.last_update_time = -float("inf")

        # stores in memory max/min/last result for each metric by trial
        self.metric_analysis = {}

        self.export_formats = export_formats
        self.status = Trial.PENDING
        self.start_time = None
        self.logdir = None
        self.runner = None
        self.result_logger = None
        self.last_debug = 0
        self.error_file = None
        self.error_msg = None
        self.custom_trial_name = None

        # Checkpointing fields
        if remote_checkpoint_dir:
            self.remote_checkpoint_dir_prefix = remote_checkpoint_dir
        else:
            self.remote_checkpoint_dir_prefix = None
        self.checkpoint_freq = checkpoint_freq
        self.checkpoint_at_end = checkpoint_at_end
        self.sync_on_checkpoint = sync_on_checkpoint
        newest_checkpoint = Checkpoint(Checkpoint.PERSISTENT, restore_path)
        self.checkpoint_manager = CheckpointManager(
            keep_checkpoints_num, checkpoint_score_attr,
            checkpoint_deleter(str(self), self.runner))
        self.checkpoint_manager.newest_checkpoint = newest_checkpoint

        # Restoration fields
        self.restoring_from = None
        self.num_failures = 0
        self.num_consecutive_start_attempts = 0

        # AutoML fields
        self.results = None
        self.best_result = None
        self.param_config = None
        self.extra_arg = None

        self._nonjson_fields = [
            "loggers",
            "sync_to_driver_fn",
            "results",
            "best_result",
            "param_config",
            "extra_arg",
        ]
        if trial_name_creator:
            self.custom_trial_name = trial_name_creator(self)

    @property
    def node_ip(self):
        return self.location.hostname

    @property
    def checkpoint(self):
        return self.checkpoint_manager.newest_checkpoint

    @classmethod
    def generate_id(cls):
        return str(uuid.uuid1().hex)[:8]

    @property
    def remote_checkpoint_dir(self):
        assert self.logdir, "Trial {}: logdir not initialized.".format(self)
        if not self.remote_checkpoint_dir_prefix:
            return None
        logdir_name = os.path.basename(self.logdir)
        return os.path.join(self.remote_checkpoint_dir_prefix, logdir_name)

    @classmethod
    def create_logdir(cls, identifier, local_dir):
        local_dir = os.path.expanduser(local_dir)
        os.makedirs(local_dir, exist_ok=True)
        return tempfile.mkdtemp(
            prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()),
            dir=local_dir)

    def init_logger(self):
        """Init logger."""
        if not self.result_logger:
            if not self.logdir:
                self.logdir = Trial.create_logdir(str(self), self.local_dir)
            else:
                os.makedirs(self.logdir, exist_ok=True)

            self.result_logger = UnifiedLogger(
                self.config,
                self.logdir,
                trial=self,
                loggers=self.loggers,
                sync_function=self.sync_to_driver_fn)

    def update_resources(self, cpu, gpu, **kwargs):
        """EXPERIMENTAL: Updates the resource requirements.

        Should only be called when the trial is not running.

        Raises:
            ValueError if trial status is running.
        """
        if self.status is Trial.RUNNING:
            raise ValueError("Cannot update resources while Trial is running.")
        self.resources = Resources(cpu, gpu, **kwargs)

    def set_runner(self, runner):
        self.runner = runner
        self.checkpoint_manager.delete = checkpoint_deleter(str(self), runner)

    def set_location(self, location):
        """Sets the location of the trial."""
        self.location = location

    def set_status(self, status):
        """Sets the status of the trial."""
        self.status = status
        if status == Trial.RUNNING:
            if self.start_time is None:
                self.start_time = time.time()

    def close_logger(self):
        """Closes logger."""
        if self.result_logger:
            self.result_logger.close()
            self.result_logger = None

    def write_error_log(self, error_msg):
        if error_msg and self.logdir:
            self.num_failures += 1
            self.error_file = os.path.join(self.logdir, "error.txt")
            with open(self.error_file, "a+") as f:
                f.write("Failure # {} (occurred at {})\n".format(
                    self.num_failures, date_str()))
                f.write(error_msg + "\n")
            self.error_msg = error_msg

    def should_stop(self, result):
        """Whether the given result meets this trial's stopping criteria."""
        if result.get(DONE):
            return True

        if callable(self.stopping_criterion):
            return self.stopping_criterion(self.trial_id, result)

        for criteria, stop_value in self.stopping_criterion.items():
            if criteria not in result:
                raise TuneError(
                    "Stopping criteria {} not provided in result {}.".format(
                        criteria, result))
            elif isinstance(criteria, dict):
                raise ValueError(
                    "Stopping criteria is now flattened by default. "
                    "Use forward slashes to nest values `key1/key2/key3`.")
            elif result[criteria] >= stop_value:
                return True
        return False

    def should_checkpoint(self):
        """Whether this trial is due for checkpointing."""
        result = self.last_result or {}
        if result.get(DONE) and self.checkpoint_at_end:
            return True
        return (self.checkpoint_freq and
                result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0)

    def has_checkpoint(self):
        return self.checkpoint.value is not None

    def clear_checkpoint(self):
        self.checkpoint.value = None
        self.restoring_from = None

    def on_checkpoint(self, checkpoint):
        """Hook for handling checkpoints taken by the Trainable.

        Args:
            checkpoint (Checkpoint): Checkpoint taken.
        """
        if checkpoint.storage == Checkpoint.MEMORY:
            # TODO(ujvl): Handle this separately to avoid restoration failure.
            self.checkpoint_manager.on_checkpoint(checkpoint)
            return
        if self.sync_on_checkpoint:
            try:
                # Wait for any other syncs to finish. We need to sync again
                # after this to handle checkpoints taken mid-sync.
                self.result_logger.wait()
            except TuneError as e:
                # Errors occurring during this wait are not fatal for this
                # checkpoint, so it should just be logged.
                logger.error(
                    "Trial %s: An error occurred during the "
                    "checkpoint pre-sync wait.", str(e))
            # Force sync down and wait before tracking the new checkpoint.
            try:
                if self.result_logger.sync_down():
                    self.result_logger.wait()
                else:
                    logger.error(
                        "Trial %s: Checkpoint sync skipped. "
                        "This should not happen.", self)
            except TuneError as e:
                if issubclass(self.get_trainable_cls(), DurableTrainable):
                    # Even though rsync failed the trainable can restore
                    # from remote durable storage.
                    logger.error("Trial %s: Sync error - %s", self, str(e))
                else:
                    # If the trainable didn't have remote storage to upload
                    # to then this checkpoint may have been lost, so we
                    # shouldn't track it with the checkpoint_manager.
                    raise e
            if not issubclass(self.get_trainable_cls(), DurableTrainable):
                if not os.path.exists(checkpoint.value):
                    raise TuneError("Trial {}: Checkpoint path {} not "
                                    "found after successful sync down.".format(
                                        self, checkpoint.value))
        self.checkpoint_manager.on_checkpoint(checkpoint)

    def on_restore(self):
        """Handles restoration completion."""
        assert self.is_restoring
        self.last_result = self.restoring_from.result
        self.restoring_from = None

    def should_recover(self):
        """Returns whether the trial qualifies for retrying.

        This is if the trial has not failed more than max_failures. Note this
        may return true even when there is no checkpoint, either because
        `self.checkpoint_freq` is `0` or because the trial failed before
        a checkpoint has been made.
        """
        return self.num_failures < self.max_failures or self.max_failures < 0

    def update_last_result(self, result, terminate=False):
        result.update(trial_id=self.trial_id, done=terminate)
        if self.experiment_tag:
            result.update(experiment_tag=self.experiment_tag)
        if self.verbose and (terminate or time.time() - self.last_debug >
                             DEBUG_PRINT_INTERVAL):
            print("Result for {}:".format(self))
            print("  {}".format(pretty_print(result).replace("\n", "\n  ")))
            self.last_debug = time.time()
        self.set_location(Location(result.get("node_ip"), result.get("pid")))
        self.last_result = result
        self.last_update_time = time.time()
        self.result_logger.on_result(self.last_result)

        for metric, value in flatten_dict(result).items():
            if isinstance(value, Number):
                if metric not in self.metric_analysis:
                    self.metric_analysis[metric] = {
                        "max": value,
                        "min": value,
                        "last": value
                    }
                else:
                    self.metric_analysis[metric]["max"] = max(
                        value, self.metric_analysis[metric]["max"])
                    self.metric_analysis[metric]["min"] = min(
                        value, self.metric_analysis[metric]["min"])
                    self.metric_analysis[metric]["last"] = value

    def get_trainable_cls(self):
        return get_trainable_cls(self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.ERROR, Trial.TERMINATED]

    @property
    def is_restoring(self):
        return self.restoring_from is not None

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
""" if self.custom_trial_name: return self.custom_trial_name if "env" in self.config: env = self.config["env"] if isinstance(env, type): env = env.__name__ identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name identifier += "_" + self.trial_id return identifier.replace("/", "_") def __getstate__(self): """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a PERSISTENT checkpoint. """ assert self.checkpoint.storage == Checkpoint.PERSISTENT, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) for key in self._nonjson_fields: state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None state["result_logger"] = None if self.result_logger: self.result_logger.flush(sync_down=False) state["__logger_started__"] = True else: state["__logger_started__"] = False return copy.deepcopy(state) def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) if state["status"] == Trial.RUNNING: state["status"] = Trial.PENDING for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) validate_trainable(self.trainable_name) if logger_started: self.init_logger()
class Trainable:
    """Abstract class for trainable models, functions, etc.

    A call to ``train()`` on a trainable will execute one logical iteration
    of training. As a rule of thumb, the execution time of one train call
    should be large enough to avoid overheads (i.e. more than a few seconds),
    but short enough to report progress periodically (i.e. at most a few
    minutes).

    Calling ``save()`` should save the training state of a trainable to disk,
    and ``restore(path)`` should restore a trainable to the given state.

    Generally you only need to implement ``setup``, ``step``,
    ``save_checkpoint``, and ``load_checkpoint`` when subclassing Trainable.

    Other implementation methods that may be helpful to override are
    ``log_result``, ``reset_config``, ``cleanup``, and ``_export_model``.

    When using Tune, Tune will convert this class into a Ray actor, which
    runs on a separate process. Tune will also change the current working
    directory of this process to ``self.logdir``. This is designed so that
    different trials that run on the same physical node won't accidentally
    write to the same location and overstep each other.

    If you want to know the original working directory path on the driver
    node, you can do so through the env variable "TUNE_ORIG_WORKING_DIR".
    It is advised that you access this path for read-only purposes and that
    you make sure that the path exists on the remote nodes.

    This class supports checkpointing to and restoring from remote storage.
    """

    _sync_function_tpl = None

    def __init__(
        self,
        config: Dict[str, Any] = None,
        logger_creator: Callable[[Dict[str, Any]], Logger] = None,
        remote_checkpoint_dir: Optional[str] = None,
        sync_function_tpl: Optional[str] = None,
    ):
        """Initialize a Trainable.

        Sets up logging and points ``self.logdir`` to a directory in which
        training outputs should be placed.

        Subclasses should prefer defining ``setup()`` instead of overriding
        ``__init__()`` directly.

        Args:
            config (dict): Trainable-specific configuration data. By default
                will be saved as ``self.config``.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
            remote_checkpoint_dir (str): Upload directory (S3 or GS path).
                This is the **per trial** directory, which is different from
                the **per checkpoint** directory.
            sync_function_tpl (str): Sync function template to use. Defaults
                to `cls._sync_function` (which defaults to `None`).
        """
        self._experiment_id = uuid.uuid4().hex
        self.config = config or {}
        trial_info = self.config.pop(TRIAL_INFO, None)

        if self.is_actor():
            disable_ipython()

        self._result_logger = self._logdir = None
        self._create_logger(self.config, logger_creator)

        self._stdout_context = self._stdout_fp = self._stdout_stream = None
        self._stderr_context = self._stderr_fp = self._stderr_stream = None
        self._stderr_logging_handler = None

        stdout_file = self.config.pop(STDOUT_FILE, None)
        stderr_file = self.config.pop(STDERR_FILE, None)
        self._open_logfiles(stdout_file, stderr_file)

        self._iteration = 0
        self._time_total = 0.0
        self._timesteps_total = None
        self._episodes_total = None
        self._time_since_restore = 0.0
        self._timesteps_since_restore = 0
        self._iterations_since_restore = 0
        self._restored = False
        self._trial_info = trial_info
        self._stdout_file = stdout_file
        self._stderr_file = stderr_file

        start_time = time.time()
        self._local_ip = self.get_current_ip()
        self.setup(copy.deepcopy(self.config))
        setup_time = time.time() - start_time
        if setup_time > SETUP_TIME_THRESHOLD:
If your " "trainable is slow to initialize, consider setting " "reuse_actors=True to reduce actor creation " "overheads.".format(setup_time)) log_sys_usage = self.config.get("log_sys_usage", False) self._start_time = start_time self._warmup_time = None self._monitor = UtilMonitor(start=log_sys_usage) self.remote_checkpoint_dir = remote_checkpoint_dir self.sync_function_tpl = sync_function_tpl or self._sync_function_tpl self.storage_client = None if self.uses_cloud_checkpointing: self.storage_client = self._create_storage_client() @property def uses_cloud_checkpointing(self): return bool(self.remote_checkpoint_dir) def _create_storage_client(self): """Returns a storage client.""" return get_sync_client( self.sync_function_tpl) or get_cloud_sync_client( self.remote_checkpoint_dir) def _storage_path(self, local_path): """Converts a `local_path` to be based off of `self.remote_checkpoint_dir`.""" rel_local_path = os.path.relpath(local_path, self.logdir) return os.path.join(self.remote_checkpoint_dir, rel_local_path) @classmethod def default_resource_request( cls, config: Dict[str, Any]) -> Union[Resources, PlacementGroupFactory]: """Provides a static resource requirement for the given configuration. This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to. .. code-block:: python @classmethod def default_resource_request(cls, config): return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}]]) Args: config[Dict[str, Any]]: The Trainable's config dict. Returns: Union[Resources, PlacementGroupFactory]: A Resources object or PlacementGroupFactory consumed by Tune for queueing. """ return None @classmethod def resource_help(cls, config): """Returns a help string for configuring this trainable's resources. Args: config (dict): The Trainer's config dict. """ return "" def get_current_ip(self): self._local_ip = ray.util.get_node_ip_address() return self._local_ip def get_auto_filled_metrics( self, now: Optional[datetime] = None, time_this_iter: Optional[float] = None, debug_metrics_only: bool = False, ) -> dict: """Return a dict with metrics auto-filled by the trainable. If ``debug_metrics_only`` is True, only metrics that don't require at least one iteration will be returned (``ray.tune.result.DEBUG_METRICS``). """ if now is None: now = datetime.today() autofilled = { TRIAL_ID: self.trial_id, "experiment_id": self._experiment_id, "date": now.strftime("%Y-%m-%d_%H-%M-%S"), "timestamp": int(time.mktime(now.timetuple())), TIME_THIS_ITER_S: time_this_iter, TIME_TOTAL_S: self._time_total, PID: os.getpid(), HOSTNAME: platform.node(), NODE_IP: self._local_ip, "config": self.config, "time_since_restore": self._time_since_restore, "timesteps_since_restore": self._timesteps_since_restore, "iterations_since_restore": self._iterations_since_restore, "warmup_time": self._warmup_time, } if debug_metrics_only: autofilled = { k: v for k, v in autofilled.items() if k in DEBUG_METRICS } return autofilled def is_actor(self): try: actor_id = ray.worker.global_worker.actor_id return actor_id != actor_id.nil() except Exception: # If global_worker is not instantiated, we're not in an actor return False def train_buffered(self, buffer_time_s: float, max_buffer_length: int = 1000): """Runs multiple iterations of training. Calls ``train()`` internally. Collects and combines multiple results. 
        This function will run ``self.train()`` repeatedly until one of
        the following conditions is met: 1) the maximum buffer length is
        reached, 2) the maximum buffer time is reached, or 3) a checkpoint
        was created. Even if the maximum time is reached, it will always
        block until at least one result is received.

        Args:
            buffer_time_s (float): Maximum time to buffer. The next result
                received after this amount of time has passed will return
                the whole buffer.
            max_buffer_length (int): Maximum number of results to buffer.
        """
        results = []

        now = time.time()
        send_buffer_at = now + buffer_time_s
        while now < send_buffer_at or not results:  # At least one result
            result = self.train()
            results.append(result)
            if result.get(DONE, False):
                # If the trial is done, return
                break
            elif result.get(SHOULD_CHECKPOINT, False):
                # If a checkpoint was created, return
                break
            elif result.get(RESULT_DUPLICATE):
                # If the function API trainable completed, return
                break
            elif len(results) >= max_buffer_length:
                # If the buffer is full, return
                break
            now = time.time()

        return results

    def train(self):
        """Runs one logical iteration of training.

        Calls ``step()`` internally. Subclasses should override ``step()``
        instead to return results.

        This method automatically fills the following fields in the result:

            `done` (bool): training is terminated. Filled only if not
                provided.

            `time_this_iter_s` (float): Time in seconds this iteration took
                to run. This may be overridden in order to override the
                system-computed time difference.

            `time_total_s` (float): Accumulated time in seconds for this
                entire experiment.

            `experiment_id` (str): Unique string identifier for this
                experiment. This id is preserved across checkpoint / restore
                calls.

            `training_iteration` (int): The index of this training iteration,
                e.g. call to train(). This is incremented after `step()` is
                called.

            `pid` (str): The pid of the training process.

            `date` (str): A formatted date of when the result was processed.

            `timestamp` (str): A UNIX timestamp of when the result was
                processed.

            `hostname` (str): Hostname of the machine hosting the training
                process.

            `node_ip` (str): Node ip of the machine hosting the training
                process.

        Returns:
            A dict that describes training progress.
        """
        if self._warmup_time is None:
            self._warmup_time = time.time() - self._start_time
        start = time.time()
        result = self.step()
        assert isinstance(result, dict), "step() needs to return a dict."

        # We do not modify internal state nor update this result if duplicate.
        if RESULT_DUPLICATE in result:
            return result

        result = result.copy()

        self._iteration += 1
        self._iterations_since_restore += 1

        if result.get(TIME_THIS_ITER_S) is not None:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = time.time() - start
        self._time_total += time_this_iter
        self._time_since_restore += time_this_iter

        result.setdefault(DONE, False)

        # self._timesteps_total should only be tracked if increments provided
        if result.get(TIMESTEPS_THIS_ITER) is not None:
            if self._timesteps_total is None:
                self._timesteps_total = 0
            self._timesteps_total += result[TIMESTEPS_THIS_ITER]
            self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]

        # self._episodes_total should only be tracked if increments provided
        if result.get(EPISODES_THIS_ITER) is not None:
            if self._episodes_total is None:
                self._episodes_total = 0
            self._episodes_total += result[EPISODES_THIS_ITER]

        # self._timesteps_total should not override user-provided total
        result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
        result.setdefault(EPISODES_TOTAL, self._episodes_total)
        result.setdefault(TRAINING_ITERATION, self._iteration)

        # Provides auto-filled neg_mean_loss for avoiding regressions
        if result.get("mean_loss"):
            result.setdefault("neg_mean_loss", -result["mean_loss"])

        now = datetime.today()
        result.update(self.get_auto_filled_metrics(now, time_this_iter))

        monitor_data = self._monitor.get_data()
        if monitor_data:
            result.update(monitor_data)

        self.log_result(result)

        if self._stdout_context:
            self._stdout_stream.flush()
        if self._stderr_context:
            self._stderr_stream.flush()

        return result

    def get_state(self):
        return {
            "experiment_id": self._experiment_id,
            "iteration": self._iteration,
            "timesteps_total": self._timesteps_total,
            "time_total": self._time_total,
            "episodes_total": self._episodes_total,
            "ray_version": ray.__version__,
        }

    def save(self, checkpoint_dir=None) -> str:
        """Saves the current model state to a checkpoint.

        Subclasses should override ``save_checkpoint()`` instead to save
        state. This method dumps additional metadata alongside the saved
        path. If a remote checkpoint dir is given, this will also sync up to
        remote storage.

        Args:
            checkpoint_dir (str): Optional dir to place the checkpoint.

        Returns:
            str: path that points to xxx.pkl file. Note the return path
            should match up with what is expected of `restore()`.
        """
        checkpoint_dir = TrainableUtil.make_checkpoint_dir(
            checkpoint_dir or self.logdir, index=self.iteration)
        checkpoint = self.save_checkpoint(checkpoint_dir)
        trainable_state = self.get_state()
        checkpoint_path = TrainableUtil.process_checkpoint(
            checkpoint,
            parent_dir=checkpoint_dir,
            trainable_state=trainable_state)

        self._postprocess_checkpoint(checkpoint_dir)

        # Maybe sync to cloud
        self._maybe_save_to_cloud(checkpoint_dir)

        return checkpoint_path

    def _postprocess_checkpoint(self, checkpoint_path: str):
        """Run extra postprocessing before the checkpoint is saved to
        cloud."""
        pass

    def _maybe_save_to_cloud(self, checkpoint_dir):
        # Derived classes like the FunctionRunner might call this
        if self.uses_cloud_checkpointing:
            self.storage_client.sync_up(checkpoint_dir,
                                        self._storage_path(checkpoint_dir))
            self.storage_client.wait_or_retry()

    def save_to_object(self):
        """Saves the current model state to a Python object.

        It also saves to disk but does not return the checkpoint path.

        Returns:
            Object holding checkpoint data.
        """
        tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir)
        checkpoint_path = self.save(tmpdir)
        # Save all files in subtree and delete the tmpdir.
obj = TrainableUtil.checkpoint_to_object(checkpoint_path) shutil.rmtree(tmpdir) return obj def restore(self, checkpoint_path): """Restores training state from a given model checkpoint. These checkpoints are returned from calls to save(). Subclasses should override ``load_checkpoint()`` instead to restore state. This method restores additional metadata saved with the checkpoint. `checkpoint_path` should match with the return from ``save()``. `checkpoint_path` can be `~/ray_results/exp/MyTrainable_abc/ checkpoint_00000/checkpoint`. Or, `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`. `self.logdir` should generally be corresponding to `checkpoint_path`, for example, `~/ray_results/exp/MyTrainable_abc`. `self.remote_checkpoint_dir` in this case, is something like, `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc` """ if self.uses_cloud_checkpointing: rel_checkpoint_dir = TrainableUtil.find_rel_checkpoint_dir( self.logdir, checkpoint_path) self.storage_client.sync_down( os.path.join(self.remote_checkpoint_dir, rel_checkpoint_dir), os.path.join(self.logdir, rel_checkpoint_dir), ) self.storage_client.wait_or_retry() # Ensure TrialCheckpoints are converted if isinstance(checkpoint_path, TrialCheckpoint): checkpoint_path = checkpoint_path.local_path with open(checkpoint_path + ".tune_metadata", "rb") as f: metadata = pickle.load(f) self._experiment_id = metadata["experiment_id"] self._iteration = metadata["iteration"] self._timesteps_total = metadata["timesteps_total"] self._time_total = metadata["time_total"] self._episodes_total = metadata["episodes_total"] saved_as_dict = metadata["saved_as_dict"] if saved_as_dict: with open(checkpoint_path, "rb") as loaded_state: checkpoint_dict = pickle.load(loaded_state) checkpoint_dict.update(tune_checkpoint_path=checkpoint_path) self.load_checkpoint(checkpoint_dict) else: self.load_checkpoint(checkpoint_path) self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = True logger.info("Restored on %s from checkpoint: %s", self.get_current_ip(), checkpoint_path) state = { "_iteration": self._iteration, "_timesteps_total": self._timesteps_total, "_time_total": self._time_total, "_episodes_total": self._episodes_total, } logger.info("Current state after restoring: %s", state) def restore_from_object(self, obj): """Restores training state from a checkpoint object. These checkpoints are returned from calls to save_to_object(). """ tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir) checkpoint_path = TrainableUtil.create_from_pickle(obj, tmpdir) self.restore(checkpoint_path) shutil.rmtree(tmpdir) def delete_checkpoint(self, checkpoint_path): """Deletes local copy of checkpoint. Args: checkpoint_path (str): Path to checkpoint. """ # Ensure TrialCheckpoints are converted if isinstance(checkpoint_path, TrialCheckpoint): checkpoint_path = checkpoint_path.local_path try: checkpoint_dir = TrainableUtil.find_checkpoint_dir(checkpoint_path) except FileNotFoundError: # The checkpoint won't exist locally if the # trial was rescheduled to another worker. logger.debug( f"Local checkpoint not found during garbage collection: " f"{self.trial_id} - {checkpoint_path}") return else: if self.uses_cloud_checkpointing: self.storage_client.delete(self._storage_path(checkpoint_dir)) self.storage_client.wait_or_retry() if os.path.exists(checkpoint_dir): shutil.rmtree(checkpoint_dir) def export_model(self, export_formats, export_dir=None): """Exports model based on export_formats. 
Subclasses should override _export_model() to actually export model to local directory. Args: export_formats (Union[list,str]): Format or list of (str) formats that should be exported. export_dir (str): Optional dir to place the exported model. Defaults to self.logdir. Returns: A dict that maps ExportFormats to successfully exported models. """ if isinstance(export_formats, str): export_formats = [export_formats] export_dir = export_dir or self.logdir return self._export_model(export_formats, export_dir) def reset(self, new_config, logger_creator=None): """Resets trial for use with new config. Subclasses should override reset_config() to actually reset actor behavior for the new config.""" self.config = new_config trial_info = new_config.pop(TRIAL_INFO, None) if trial_info: self._trial_info = trial_info self._result_logger.flush() self._result_logger.close() if logger_creator: logger.debug("Logger reset.") self._create_logger(new_config.copy(), logger_creator) else: logger.debug("Did not reset logger. Got: " f"trainable.reset(logger_creator={logger_creator}).") stdout_file = new_config.pop(STDOUT_FILE, None) stderr_file = new_config.pop(STDERR_FILE, None) self._close_logfiles() self._open_logfiles(stdout_file, stderr_file) success = self.reset_config(new_config) if not success: return False # Reset attributes. Will be overwritten by `restore` if a checkpoint # is provided. self._iteration = 0 self._time_total = 0.0 self._timesteps_total = None self._episodes_total = None self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = False return True def reset_config(self, new_config): """Resets configuration without restarting the trial. This method is optional, but can be implemented to speed up algorithms such as PBT, and to allow performance optimizations such as running experiments with reuse_actors=True. Args: new_config (dict): Updated hyperparameter configuration for the trainable. Returns: True if reset was successful else False. """ return False def _create_logger( self, config: Dict[str, Any], logger_creator: Callable[[Dict[str, Any]], Logger] = None, ): """Create logger from logger creator. Sets _logdir and _result_logger. `_logdir` is the **per trial** directory for the Trainable. """ if logger_creator: self._result_logger = logger_creator(config) self._logdir = self._result_logger.logdir else: from ray.tune.logger import UnifiedLogger logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") ray._private.utils.try_to_create_directory(DEFAULT_RESULTS_DIR) self._logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) self._result_logger = UnifiedLogger(config, self._logdir, loggers=None) def _open_logfiles(self, stdout_file, stderr_file): """Create loggers. 
Open stdout and stderr logfiles.""" if stdout_file: stdout_path = os.path.expanduser( os.path.join(self._logdir, stdout_file)) self._stdout_fp = open(stdout_path, "a+") self._stdout_stream = Tee(sys.stdout, self._stdout_fp) self._stdout_context = redirect_stdout(self._stdout_stream) self._stdout_context.__enter__() if stderr_file: stderr_path = os.path.expanduser( os.path.join(self._logdir, stderr_file)) self._stderr_fp = open(stderr_path, "a+") self._stderr_stream = Tee(sys.stderr, self._stderr_fp) self._stderr_context = redirect_stderr(self._stderr_stream) self._stderr_context.__enter__() # Add logging handler to root ray logger formatter = logging.Formatter("[%(levelname)s %(asctime)s] " "%(filename)s: %(lineno)d " "%(message)s") self._stderr_logging_handler = logging.StreamHandler( self._stderr_fp) self._stderr_logging_handler.setFormatter(formatter) ray.logger.addHandler(self._stderr_logging_handler) def _close_logfiles(self): """Close stdout and stderr logfiles.""" if self._stderr_logging_handler: ray.logger.removeHandler(self._stderr_logging_handler) if self._stdout_context: self._stdout_stream.flush() self._stdout_context.__exit__(None, None, None) self._stdout_fp.close() self._stdout_context = None if self._stderr_context: self._stderr_stream.flush() self._stderr_context.__exit__(None, None, None) self._stderr_fp.close() self._stderr_context = None def stop(self): """Releases all resources used by this trainable. Calls ``Trainable.cleanup`` internally. Subclasses should override ``Trainable.cleanup`` for custom cleanup procedures. """ self._result_logger.flush() self._result_logger.close() if self._monitor.is_alive(): self._monitor.stop() self._monitor.join() self.cleanup() self._close_logfiles() @property def logdir(self): """Directory of the results and checkpoints for this Trainable. Tune will automatically sync this folder with the driver if execution is distributed. Note that the current working directory will also be changed to this. """ return os.path.join(self._logdir, "") @property def trial_name(self): """Trial name for the corresponding trial of this Trainable. This is not set if not using Tune. .. code-block:: python name = self.trial_name """ if self._trial_info: return self._trial_info.trial_name else: return "default" @property def trial_id(self): """Trial ID for the corresponding trial of this Trainable. This is not set if not using Tune. .. code-block:: python trial_id = self.trial_id """ if self._trial_info: return self._trial_info.trial_id else: return "default" @property def trial_resources(self) -> Union[Resources, PlacementGroupFactory]: """Resources currently assigned to the trial of this Trainable. This is not set if not using Tune. .. code-block:: python trial_resources = self.trial_resources """ if self._trial_info: return self._trial_info.trial_resources else: return "default" @property def iteration(self): """Current training iteration. This value is automatically incremented every time `train()` is called and is automatically inserted into the training result dict. """ return self._iteration @property def training_iteration(self): """Current training iteration (same as `self.iteration`). This value is automatically incremented every time `train()` is called and is automatically inserted into the training result dict. """ return self._iteration def get_config(self): """Returns configuration passed in by Tune.""" return self.config def step(self): """Subclasses should override this to implement train(). 
The return value will be automatically passed to the loggers. Users can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables. .. versionadded:: 0.8.7 Returns: A dict that describes training progress. """ if self._implements_method("_train") and log_once("_train"): raise DeprecationWarning( "Trainable._train is deprecated and is now removed. Override " "Trainable.step instead.") raise NotImplementedError def save_checkpoint(self, tmp_checkpoint_dir): """Subclasses should override this to implement ``save()``. Warning: Do not rely on absolute paths in the implementation of ``Trainable.save_checkpoint`` and ``Trainable.load_checkpoint``. Use ``validate_save_restore`` to catch ``Trainable.save_checkpoint``/ ``Trainable.load_checkpoint`` errors before execution. >>> from ray.tune.utils import validate_save_restore >>> validate_save_restore(MyTrainableClass) >>> validate_save_restore(MyTrainableClass, use_object_store=True) .. versionadded:: 0.8.7 Args: tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved. Returns: A dict or string. If string, the return value is expected to be prefixed by `tmp_checkpoint_dir`. If dict, the return value will be automatically serialized by Tune and passed to ``Trainable.load_checkpoint()``. Examples: >>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) "/tmp/checkpoint_1/my_checkpoint_file" >>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) {"some": "data"} >>> trainable.save_checkpoint("/tmp/bad_example") "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error. """ if self._implements_method("_save") and log_once("_save"): raise DeprecationWarning( "Trainable._save is deprecated and is now removed. Override " "Trainable.save_checkpoint instead.") raise NotImplementedError def load_checkpoint(self, checkpoint): """Subclasses should override this to implement restore(). Warning: In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in ``Trainable.save_checkpoint`` may be changed. If ``Trainable.save_checkpoint`` returned a prefixed string, the prefix of the checkpoint string returned by ``Trainable.save_checkpoint`` may be changed. This is because trial pausing depends on temporary directories. The directory structure under the checkpoint_dir provided to ``Trainable.save_checkpoint`` is preserved. See the example below. .. code-block:: python class Example(Trainable): def save_checkpoint(self, checkpoint_path): print(checkpoint_path) return os.path.join(checkpoint_path, "my/check/point") def load_checkpoint(self, checkpoint): print(checkpoint) >>> trainer = Example() >>> obj = trainer.save_to_object() # This is used when PAUSED. <logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point >>> trainer.restore_from_object(obj) # Note the different prefix. <logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point .. versionadded:: 0.8.7 Args: checkpoint (str|dict): If dict, the return value is as returned by `save_checkpoint`. If a string, then it is a checkpoint path that may have a different prefix than that returned by `save_checkpoint`. The directory structure underneath the `checkpoint_dir` `save_checkpoint` is preserved. 
""" if self._implements_method("_restore") and log_once("_restore"): raise DeprecationWarning( "Trainable._restore is deprecated and is now removed. " "Override Trainable.load_checkpoint instead.") raise NotImplementedError def setup(self, config): """Subclasses should override this for custom initialization. .. versionadded:: 0.8.7 Args: config (dict): Hyperparameters and other configs given. Copy of `self.config`. """ if self._implements_method("_setup") and log_once("_setup"): raise DeprecationWarning( "Trainable._setup is deprecated and is now removed. Override " "Trainable.setup instead.") pass def log_result(self, result): """Subclasses can optionally override this to customize logging. The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the ``loggers`` parameter in ``tune.run`` when overriding this function. .. versionadded:: 0.8.7 Args: result (dict): Training result returned by step(). """ if self._implements_method("_log_result") and log_once("_log_result"): raise DeprecationWarning( "Trainable._log_result is deprecated and is now removed. " "Override Trainable.log_result instead.") self._result_logger.on_result(result) def cleanup(self): """Subclasses should override this for any cleanup on stop. If any Ray actors are launched in the Trainable (i.e., with a RLlib trainer), be sure to kill the Ray actor process here. You can kill a Ray actor by calling `actor.__ray_terminate__.remote()` on the actor. .. versionadded:: 0.8.7 """ if self._implements_method("_stop") and log_once("_stop"): raise DeprecationWarning( "Trainable._stop is deprecated and is now removed. Override " "Trainable.cleanup instead.") pass def _export_model(self, export_formats, export_dir): """Subclasses should override this to export model. Args: export_formats (list): List of formats that should be exported. export_dir (str): Directory to place exported models. Return: A dict that maps ExportFormats to successfully exported models. """ return {} def _implements_method(self, key): return hasattr(self, key) and callable(getattr(self, key))
class Trial(object): """A trial object holds the state for one model training run. Trials are themselves managed by the TrialRunner class, which implements the event loop for submitting trial runs to a Ray cluster. Trials start in the PENDING state, and transition to RUNNING once started. On error it transitions to ERROR, otherwise TERMINATED on success. """ PENDING = "PENDING" RUNNING = "RUNNING" PAUSED = "PAUSED" TERMINATED = "TERMINATED" ERROR = "ERROR" def __init__(self, trainable_name, config=None, trial_id=None, local_dir=DEFAULT_RESULTS_DIR, evaluated_params=None, experiment_tag="", resources=None, stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, keep_checkpoints_num=None, checkpoint_score_attr="", export_formats=None, restore_path=None, trial_name_creator=None, loggers=None, sync_to_driver_fn=None, max_failures=0): """Initialize a new trial. The args here take the same meaning as the command line flags defined in ray.tune.config_parser. """ validate_trainable(trainable_name) # Trial config self.trainable_name = trainable_name self.trial_id = Trial.generate_id() if trial_id is None else trial_id self.config = config or {} self.local_dir = local_dir # This remains unexpanded for syncing. #: Parameters that Tune varies across searches. self.evaluated_params = evaluated_params or {} self.experiment_tag = experiment_tag trainable_cls = self.get_trainable_cls() if trainable_cls and hasattr(trainable_cls, "default_resource_request"): default_resources = trainable_cls.default_resource_request( self.config) if default_resources: if resources: raise ValueError( "Resources for {} have been automatically set to {} " "by its `default_resource_request()` method. Please " "clear the `resources_per_trial` option.".format( trainable_cls, default_resources)) resources = default_resources self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} self.loggers = loggers self.sync_to_driver_fn = sync_to_driver_fn self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = {} self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end self.history = [] self.keep_checkpoints_num = keep_checkpoints_num self._cmp_greater = not checkpoint_score_attr.startswith("min-") self.best_checkpoint_attr_value = -float("inf") \ if self._cmp_greater else float("inf") # Strip off "min-" from checkpoint attribute self.checkpoint_score_attr = checkpoint_score_attr \ if self._cmp_greater else checkpoint_score_attr[4:] self._checkpoint = Checkpoint(storage=Checkpoint.DISK, value=restore_path) self.export_formats = export_formats self.status = Trial.PENDING self.logdir = None self.runner = None self.result_logger = None self.last_debug = 0 self.error_file = None self.error_msg = None self.num_failures = 0 self.custom_trial_name = None # AutoML fields self.results = None self.best_result = None self.param_config = None self.extra_arg = None self._nonjson_fields = [ "_checkpoint", "loggers", "sync_to_driver_fn", "results", "best_result", "param_config", "extra_arg", ] if trial_name_creator: self.custom_trial_name = trial_name_creator(self) @classmethod def generate_id(cls): return str(uuid.uuid1().hex)[:8] @classmethod def create_logdir(cls, identifier, local_dir): local_dir = os.path.expanduser(local_dir) if not os.path.exists(local_dir): os.makedirs(local_dir) return tempfile.mkdtemp(prefix="{}_{}".format( identifier[:MAX_LEN_IDENTIFIER], 
date_str()), dir=local_dir) def init_logger(self): """Init logger.""" if not self.result_logger: if not self.logdir: self.logdir = Trial.create_logdir(str(self), self.local_dir) elif not os.path.exists(self.logdir): os.makedirs(self.logdir) self.result_logger = UnifiedLogger( self.config, self.logdir, trial=self, loggers=self.loggers, sync_function=self.sync_to_driver_fn) def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. Should only be called when the trial is not running. Raises: ValueError if trial status is running. """ if self.status is Trial.RUNNING: raise ValueError("Cannot update resources while Trial is running.") self.resources = Resources(cpu, gpu, **kwargs) def sync_logger_to_new_location(self, worker_ip): """Updates the logger location. Also pushes logdir to worker_ip, allowing for cross-node recovery. """ if self.result_logger: self.result_logger.sync_results_to_new_location(worker_ip) def close_logger(self): """Close logger.""" if self.result_logger: self.result_logger.close() self.result_logger = None def write_error_log(self, error_msg): if error_msg and self.logdir: self.num_failures += 1 # may be moved to outer scope? error_file = os.path.join(self.logdir, "error_{}.txt".format(date_str())) with open(error_file, "w") as f: f.write(error_msg) self.error_file = error_file self.error_msg = error_msg def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" if result.get(DONE): return True if callable(self.stopping_criterion): return self.stopping_criterion(self.trial_id, result) for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( "Stopping criteria {} not provided in result {}.".format( criteria, result)) elif isinstance(criteria, dict): raise ValueError( "Stopping criteria is now flattened by default. " "Use forward slashes to nest values `key1/key2/key3`.") elif result[criteria] >= stop_value: return True return False def should_checkpoint(self): """Whether this trial is due for checkpointing.""" result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True if self.checkpoint_freq: return result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0 else: return False def has_checkpoint(self): return self._checkpoint.value is not None def clear_checkpoint(self): self._checkpoint.value = None def should_recover(self): """Returns whether the trial qualifies for retrying. This is if the trial has not failed more than max_failures. Note this may return true even when there is no checkpoint, either because `self.checkpoint_freq` is `0` or because the trial failed before a checkpoint has been made. """ return self.num_failures < self.max_failures or self.max_failures < 0 def update_last_result(self, result, terminate=False): result.update(trial_id=self.trial_id, done=terminate) if self.experiment_tag: result.update(experiment_tag=self.experiment_tag) if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL): print("Result for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() self.last_result = result self.last_update_time = time.time() self.result_logger.on_result(self.last_result) def compare_checkpoints(self, attr_mean): """Compares two checkpoints based on the attribute attr_mean param. Greater than is used by default. If command-line parameter checkpoint_score_attr starts with "min-" less than is used. 
        Arguments:
            attr_mean: Mean of the attribute value for the current
                checkpoint.

        Returns:
            True if attr_mean beats the best checkpoint's value under the
            selected comparison (greater-than by default, or less-than when
            ``checkpoint_score_attr`` starts with "min-"); False otherwise.
        """
        if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
            return True
        elif (not self._cmp_greater
              and attr_mean < self.best_checkpoint_attr_value):
            return True
        return False

    def get_trainable_cls(self):
        return get_trainable_cls(self.trainable_name)

    def set_verbose(self, verbose):
        self.verbose = verbose

    def is_finished(self):
        return self.status in [Trial.TERMINATED, Trial.ERROR]

    @property
    def node_ip(self):
        return self.last_result.get("node_ip")

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Combines ``env`` with ``trainable_name`` and ``trial_id``.

        Can be overridden with a custom string creator.
        """
        if self.custom_trial_name:
            return self.custom_trial_name

        if "env" in self.config:
            env = self.config["env"]
            if isinstance(env, type):
                env = env.__name__
            identifier = "{}_{}".format(self.trainable_name, env)
        else:
            identifier = self.trainable_name
        identifier += "_" + self.trial_id
        return identifier.replace("/", "_")

    def __getstate__(self):
        """Memento generator for Trial.

        Sets RUNNING trials to PENDING, and flushes the result logger.
        Note this can only occur if the trial holds a DISK checkpoint.
        """
        assert self._checkpoint.storage == Checkpoint.DISK, (
            "Checkpoint must not be in-memory.")
        state = self.__dict__.copy()
        state["resources"] = resources_to_json(self.resources)

        for key in self._nonjson_fields:
            state[key] = binary_to_hex(cloudpickle.dumps(state.get(key)))

        state["runner"] = None
        state["result_logger"] = None
        if self.result_logger:
            self.result_logger.flush()
            state["__logger_started__"] = True
        else:
            state["__logger_started__"] = False
        return copy.deepcopy(state)

    def __setstate__(self, state):
        logger_started = state.pop("__logger_started__")
        state["resources"] = json_to_resources(state["resources"])
        if state["status"] == Trial.RUNNING:
            state["status"] = Trial.PENDING
        for key in self._nonjson_fields:
            state[key] = cloudpickle.loads(hex_to_binary(state[key]))

        self.__dict__.update(state)
        validate_trainable(self.trainable_name)
        if logger_started:
            self.init_logger()
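
# Illustrative sketch (hypothetical usage, assuming a trainable has been
# registered under the name "my_trainable"): stopping criteria are flat
# result keys compared with ``>=``, so metrics to be minimized are expressed
# through negated counterparts such as ``neg_mean_loss``.
trial = Trial(
    "my_trainable",
    stopping_criterion={"training_iteration": 10, "neg_mean_loss": -0.01},
)
result = {"training_iteration": 10, "neg_mean_loss": -0.5}
assert trial.should_stop(result)  # 10 >= 10 triggers the stop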
class Trial(object): """A trial object holds the state for one model training run. Trials are themselves managed by the TrialRunner class, which implements the event loop for submitting trial runs to a Ray cluster. Trials start in the PENDING state, and transition to RUNNING once started. On error it transitions to ERROR, otherwise TERMINATED on success. """ PENDING = "PENDING" RUNNING = "RUNNING" PAUSED = "PAUSED" TERMINATED = "TERMINATED" ERROR = "ERROR" def __init__(self, trainable_name, config=None, trial_id=None, local_dir=DEFAULT_RESULTS_DIR, experiment_tag="", resources=None, stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, keep_checkpoints_num=None, checkpoint_score_attr="", export_formats=None, restore_path=None, upload_dir=None, trial_name_creator=None, loggers=None, sync_function=None, max_failures=0): """Initialize a new trial. The args here take the same meaning as the command line flags defined in ray.tune.config_parser. """ Trial._registration_check(trainable_name) # Trial config self.trainable_name = trainable_name self.config = config or {} self.local_dir = os.path.expanduser(local_dir) self.experiment_tag = experiment_tag self.resources = ( resources or self._get_trainable_cls().default_resource_request(self.config)) self.stopping_criterion = stopping_criterion or {} self.upload_dir = upload_dir self.loggers = loggers self.sync_function = sync_function validate_sync_function(sync_function) self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = {} self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end self.history = [] self.keep_checkpoints_num = keep_checkpoints_num self._cmp_greater = not checkpoint_score_attr.startswith("min-") self.best_checkpoint_attr_value = -float("inf") \ if self._cmp_greater else float("inf") # Strip off "min-" from checkpoint attribute self.checkpoint_score_attr = checkpoint_score_attr \ if self._cmp_greater else checkpoint_score_attr[4:] self._checkpoint = Checkpoint(storage=Checkpoint.DISK, value=restore_path) self.export_formats = export_formats self.status = Trial.PENDING self.logdir = None self.runner = None self.result_logger = None self.last_debug = 0 self.trial_id = Trial.generate_id() if trial_id is None else trial_id self.error_file = None self.num_failures = 0 self.custom_trial_name = None # AutoML fields self.results = None self.best_result = None self.param_config = None self.extra_arg = None self._nonjson_fields = [ "_checkpoint", "loggers", "sync_function", "results", "best_result", "param_config", "extra_arg", ] if trial_name_creator: self.custom_trial_name = trial_name_creator(self) @classmethod def _registration_check(cls, trainable_name): if not has_trainable(trainable_name): # Make sure rllib agents are registered from ray import rllib # noqa: F401 if not has_trainable(trainable_name): raise TuneError("Unknown trainable: " + trainable_name) @classmethod def generate_id(cls): return binary_to_hex(_random_string())[:8] def init_logger(self): """Init logger.""" if not self.result_logger: if not os.path.exists(self.local_dir): os.makedirs(self.local_dir) if not self.logdir: self.logdir = tempfile.mkdtemp(prefix="{}_{}".format( str(self)[:MAX_LEN_IDENTIFIER], date_str()), dir=self.local_dir) elif not os.path.exists(self.logdir): os.makedirs(self.logdir) self.result_logger = UnifiedLogger( self.config, self.logdir, upload_uri=self.upload_dir, loggers=self.loggers, 
sync_function=self.sync_function) def update_resources(self, cpu, gpu, **kwargs): """EXPERIMENTAL: Updates the resource requirements. Should only be called when the trial is not running. Raises: ValueError if trial status is running. """ if self.status is Trial.RUNNING: raise ValueError("Cannot update resources while Trial is running.") self.resources = Resources(cpu, gpu, **kwargs) def sync_logger_to_new_location(self, worker_ip): """Updates the logger location. Also pushes logdir to worker_ip, allowing for cross-node recovery. """ if self.result_logger: self.result_logger.sync_results_to_new_location(worker_ip) def close_logger(self): """Close logger.""" if self.result_logger: self.result_logger.close() self.result_logger = None def write_error_log(self, error_msg): if error_msg and self.logdir: self.num_failures += 1 # may be moved to outer scope? error_file = os.path.join(self.logdir, "error_{}.txt".format(date_str())) with open(error_file, "w") as f: f.write(error_msg) self.error_file = error_file def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" if result.get(DONE): return True for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( "Stopping criteria {} not provided in result {}.".format( criteria, result)) if result[criteria] >= stop_value: return True return False def should_checkpoint(self): """Whether this trial is due for checkpointing.""" result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True if self.checkpoint_freq: return result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0 else: return False def progress_string(self): """Returns a progress message for printing out to the console.""" if not self.last_result: return self._status_string() def location_string(hostname, pid): if hostname == os.uname()[1]: return 'pid={}'.format(pid) else: return '{} pid={}'.format(hostname, pid) pieces = [ '{}'.format(self._status_string()), '[{}]'.format(self.resources.summary_string()), '[{}]'.format( location_string(self.last_result.get(HOSTNAME), self.last_result.get(PID))), '{} s'.format(int(self.last_result.get(TIME_TOTAL_S))) ] if self.last_result.get(TRAINING_ITERATION) is not None: pieces.append('{} iter'.format( self.last_result[TRAINING_ITERATION])) if self.last_result.get(TIMESTEPS_TOTAL) is not None: pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL])) if self.last_result.get(EPISODE_REWARD_MEAN) is not None: pieces.append('{} rew'.format( format(self.last_result[EPISODE_REWARD_MEAN], '.3g'))) if self.last_result.get(MEAN_LOSS) is not None: pieces.append('{} loss'.format( format(self.last_result[MEAN_LOSS], '.3g'))) if self.last_result.get(MEAN_ACCURACY) is not None: pieces.append('{} acc'.format( format(self.last_result[MEAN_ACCURACY], '.3g'))) return ', '.join(pieces) def _status_string(self): return "{}{}".format( self.status, ", {} failures: {}".format( self.num_failures, self.error_file) if self.error_file else "") def has_checkpoint(self): return self._checkpoint.value is not None def clear_checkpoint(self): self._checkpoint.value = None def should_recover(self): """Returns whether the trial qualifies for restoring. This is if a checkpoint frequency is set and has not failed more than max_failures. This may return true even when there may not yet be a checkpoint. 
""" return (self.checkpoint_freq > 0 and (self.num_failures < self.max_failures or self.max_failures < 0)) def update_last_result(self, result, terminate=False): if terminate: result.update(done=True) if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL): print("Result for {}:".format(self)) print(" {}".format(pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() self.last_result = result self.last_update_time = time.time() self.result_logger.on_result(self.last_result) def compare_checkpoints(self, attr_mean): """Compares two checkpoints based on the attribute attr_mean param. Greater than is used by default. If command-line parameter checkpoint_score_attr starts with "min-" less than is used. Arguments: attr_mean: mean of attribute value for the current checkpoint Returns: True: when attr_mean is greater than previous checkpoint attr_mean and greater than function is selected when attr_mean is less than previous checkpoint attr_mean and less than function is selected False: when attr_mean is not in alignment with selected cmp fn """ if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value: return True elif (not self._cmp_greater and attr_mean < self.best_checkpoint_attr_value): return True return False def _get_trainable_cls(self): return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) def set_verbose(self, verbose): self.verbose = verbose def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] def __repr__(self): return str(self) def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``. Can be overriden with a custom string creator. """ if self.custom_trial_name: return self.custom_trial_name if "env" in self.config: env = self.config["env"] if isinstance(env, type): env = env.__name__ identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: identifier += "_" + self.experiment_tag return identifier.replace("/", "_") def __getstate__(self): """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a DISK checkpoint. """ assert self._checkpoint.storage == Checkpoint.DISK, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) for key in self._nonjson_fields: state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None state["result_logger"] = None if self.result_logger: self.result_logger.flush() state["__logger_started__"] = True else: state["__logger_started__"] = False return copy.deepcopy(state) def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) if state["status"] == Trial.RUNNING: state["status"] = Trial.PENDING for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) Trial._registration_check(self.trainable_name) if logger_started: self.init_logger()
class TrackSession: """Manages results for a single session. Represents a single Trial in an experiment. This is automatically created when using ``tune.run``. Attributes: trial_name (str): Custom trial name. experiment_dir (str): Directory where results for all trials are stored. Each session is stored into a unique directory inside experiment_dir. upload_dir (str): Directory to sync results to. trial_config (dict): Parameters that will be logged to disk. _tune_reporter (StatusReporter): For rerouting when using Tune. Will not instantiate logging if not None. """ def __init__(self, trial_name=None, experiment_dir=None, upload_dir=None, trial_config=None, _tune_reporter=None): self._experiment_dir = None self._logdir = None self._upload_dir = None self.trial_config = None self._iteration = -1 self.is_tune_session = bool(_tune_reporter) if self.is_tune_session: self._logger = _ReporterHook(_tune_reporter) self._logdir = _tune_reporter.logdir self._trial_name = _tune_reporter.trial_name self._trial_id = _tune_reporter.trial_id else: self._trial_id = Trial.generate_id() self._trial_name = trial_name or self._trial_id self._initialize_logging(experiment_dir, upload_dir, trial_config) def _initialize_logging(self, experiment_dir=None, upload_dir=None, trial_config=None): if upload_dir: raise NotImplementedError("Upload Dir is not yet implemented.") # TODO(rliaw): In other parts of the code, this is `local_dir`. if experiment_dir is None: experiment_dir = os.path.join(DEFAULT_RESULTS_DIR, "default") self._experiment_dir = os.path.expanduser(experiment_dir) # TODO(rliaw): Refactor `logdir` to `trial_dir`. self._logdir = Trial.create_logdir(self.trial_name, self._experiment_dir) self._upload_dir = upload_dir self.trial_config = trial_config or {} # misc metadata to save as well self.trial_config["trial_id"] = self.trial_id self._logger = UnifiedLogger(self.trial_config, self._logdir) def log(self, **metrics): """Logs all named arguments specified in `metrics`. This will log trial metrics locally, and they will be synchronized with the driver periodically through ray. Arguments: metrics: named arguments with corresponding values to log. """ self._iteration += 1 # TODO: Implement a batching mechanism for multiple calls to `log` # within the same iteration. metrics_dict = metrics.copy() metrics_dict.update({"trial_id": self.trial_id}) # TODO: Move Trainable autopopulation to a util function metrics_dict.setdefault(TRAINING_ITERATION, self._iteration) self._logger.on_result(metrics_dict) def close(self): """Closes loggers. No need to call this when using ``tune.run``. """ self.trial_config["trial_completed"] = True self.trial_config["end_time"] = datetime.now().isoformat() # TODO(rliaw): Have Tune support updated configs self._logger.update_config(self.trial_config) self._logger.flush() self._logger.close() @property def logdir(self): """Trial logdir (subdir of given experiment directory)""" return self._logdir @property def trial_name(self): """Trial name for the corresponding trial of this Trainable""" return self._trial_name @property def trial_id(self): """Trial id for the corresponding trial of this Trainable""" return self._trial_id
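
# Illustrative sketch (not part of the original module): using TrackSession
# standalone, outside of ``tune.run``. The directory and metric names are
# hypothetical; each ``log()`` call becomes one result row written through
# the session's UnifiedLogger.
session = TrackSession(
    trial_name="demo",
    experiment_dir="~/ray_results/track_demo",
    trial_config={"lr": 0.01},
)
for step in range(3):
    session.log(mean_loss=1.0 / (step + 1))
session.close()
print("results written to", session.logdir)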
class Trial(object): """A trial object holds the state for one model training run. Trials are themselves managed by the TrialRunner class, which implements the event loop for submitting trial runs to a Ray cluster. Trials start in the PENDING state, and transition to RUNNING once started. On error it transitions to ERROR, otherwise TERMINATED on success. """ PENDING = "PENDING" RUNNING = "RUNNING" PAUSED = "PAUSED" TERMINATED = "TERMINATED" ERROR = "ERROR" def __init__(self, trainable_name, config=None, trial_id=None, local_dir=DEFAULT_RESULTS_DIR, experiment_tag="", resources=None, stopping_criterion=None, checkpoint_freq=0, checkpoint_at_end=False, restore_path=None, upload_dir=None, trial_name_creator=None, custom_loggers=None, sync_function=None, max_failures=0): """Initialize a new trial. The args here take the same meaning as the command line flags defined in ray.tune.config_parser. """ Trial._registration_check(trainable_name) # Trial config self.trainable_name = trainable_name self.config = config or {} self.local_dir = os.path.expanduser(local_dir) self.experiment_tag = experiment_tag self.resources = ( resources or self._get_trainable_cls().default_resource_request(self.config)) self.stopping_criterion = stopping_criterion or {} self.upload_dir = upload_dir self.custom_loggers = custom_loggers self.sync_function = sync_function validate_sync_function(sync_function) self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = None self.last_update_time = -float("inf") self.checkpoint_freq = checkpoint_freq self.checkpoint_at_end = checkpoint_at_end self._checkpoint = Checkpoint(storage=Checkpoint.DISK, value=restore_path) self.status = Trial.PENDING self.location = None self.logdir = None self.result_logger = None self.last_debug = 0 self.trial_id = Trial.generate_id() if trial_id is None else trial_id self.error_file = None self.num_failures = 0 self.trial_name = None if trial_name_creator: self.trial_name = trial_name_creator(self) @classmethod def _registration_check(cls, trainable_name): if not has_trainable(trainable_name): # Make sure rllib agents are registered from ray import rllib # noqa: F401 if not has_trainable(trainable_name): raise TuneError("Unknown trainable: " + trainable_name) @classmethod def generate_id(cls): return binary_to_hex(random_string())[:8] def init_logger(self): """Init logger.""" if not self.result_logger: if not os.path.exists(self.local_dir): os.makedirs(self.local_dir) if not self.logdir: self.logdir = tempfile.mkdtemp(prefix="{}_{}".format( str(self)[:MAX_LEN_IDENTIFIER], date_str()), dir=self.local_dir) elif not os.path.exists(self.logdir): os.makedirs(self.logdir) self.result_logger = UnifiedLogger( self.config, self.logdir, upload_uri=self.upload_dir, custom_loggers=self.custom_loggers, sync_function=self.sync_function) def close_logger(self): """Close logger.""" if self.result_logger: self.result_logger.close() self.result_logger = None def write_error_log(self, error_msg): if error_msg and self.logdir: self.num_failures += 1 # may be moved to outer scope? 
error_file = os.path.join(self.logdir, "error_{}.txt".format(date_str())) with open(error_file, "w") as f: f.write(error_msg) self.error_file = error_file def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" if result.get(DONE): return True for criteria, stop_value in self.stopping_criterion.items(): if criteria not in result: raise TuneError( "Stopping criteria {} not provided in result {}.".format( criteria, result)) if result[criteria] >= stop_value: return True return False def should_checkpoint(self): """Whether this trial is due for checkpointing.""" result = self.last_result or {} if result.get(DONE) and self.checkpoint_at_end: return True if self.checkpoint_freq: return result.get(TRAINING_ITERATION, 0) % self.checkpoint_freq == 0 else: return False def progress_string(self): """Returns a progress message for printing out to the console.""" if self.last_result is None: return self._status_string() def location_string(hostname, pid): if hostname == os.uname()[1]: return 'pid={}'.format(pid) else: return '{} pid={}'.format(hostname, pid) pieces = [ '{} [{}]'.format( self._status_string(), location_string(self.last_result.get(HOSTNAME), self.last_result.get(PID))), '{} s'.format(int(self.last_result.get(TIME_TOTAL_S))) ] if self.last_result.get(TRAINING_ITERATION) is not None: pieces.append('{} iter'.format( self.last_result[TRAINING_ITERATION])) if self.last_result.get(TIMESTEPS_TOTAL) is not None: pieces.append('{} ts'.format(self.last_result[TIMESTEPS_TOTAL])) if self.last_result.get("episode_reward_mean") is not None: pieces.append('{} rew'.format( format(self.last_result["episode_reward_mean"], '.3g'))) if self.last_result.get("mean_loss") is not None: pieces.append('{} loss'.format( format(self.last_result["mean_loss"], '.3g'))) if self.last_result.get("mean_accuracy") is not None: pieces.append('{} acc'.format( format(self.last_result["mean_accuracy"], '.3g'))) return ', '.join(pieces) def _status_string(self): return "{}{}".format( self.status, ", {} failures: {}".format( self.num_failures, self.error_file) if self.error_file else "") def has_checkpoint(self): return self._checkpoint.value is not None def should_recover(self): """Returns whether the trial qualifies for restoring. This is if a checkpoint frequency is set and has not failed more than max_failures. This may return true even when there may not yet be a checkpoint. """ return (self.checkpoint_freq > 0 and self.num_failures < self.max_failures) def update_last_result(self, result, terminate=False): if terminate: result.update(done=True) if self.verbose and (terminate or time.time() - self.last_debug > DEBUG_PRINT_INTERVAL): logger.info("Result for {}:".format(self)) logger.info(" {}".format( pretty_print(result).replace("\n", "\n "))) self.last_debug = time.time() self.last_result = result self.last_update_time = time.time() self.result_logger.on_result(self.last_result) def _get_trainable_cls(self): return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) def set_verbose(self, verbose): self.verbose = verbose def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] def __repr__(self): return str(self) def __str__(self): """Combines ``env`` with ``trainable_name`` and ``experiment_tag``. Can be overriden with a custom string creator. 
""" if self.trial_name: return self.trial_name if "env" in self.config: env = self.config["env"] if isinstance(env, type): env = env.__name__ identifier = "{}_{}".format(self.trainable_name, env) else: identifier = self.trainable_name if self.experiment_tag: identifier += "_" + self.experiment_tag return identifier.replace("/", "_") def __getstate__(self): """Memento generator for Trial. Sets RUNNING trials to PENDING, and flushes the result logger. Note this can only occur if the trial holds a DISK checkpoint. """ assert self._checkpoint.storage == Checkpoint.DISK, ( "Checkpoint must not be in-memory.") state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) pickle_data = { "_checkpoint": self._checkpoint, "config": self.config, "custom_loggers": self.custom_loggers, "sync_function": self.sync_function } for key, value in pickle_data.items(): state[key] = binary_to_hex(cloudpickle.dumps(value)) state["runner"] = None state["result_logger"] = None if self.status == Trial.RUNNING: state["status"] = Trial.PENDING if self.result_logger: self.result_logger.flush() state["__logger_started__"] = True else: state["__logger_started__"] = False return copy.deepcopy(state) def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) for key in [ "_checkpoint", "config", "custom_loggers", "sync_function" ]: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) Trial._registration_check(self.trainable_name) if logger_started: self.init_logger()
class Trainable(object): """Abstract class for trainable models, functions, etc. A call to ``train()`` on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes). Calling ``save()`` should save the training state of a trainable to disk, and ``restore(path)`` should restore a trainable to the given state. Generally you only need to implement ``_setup``, ``_train``, ``_save``, and ``_restore`` when subclassing Trainable. Other implementation methods that may be helpful to override are ``_log_result``, ``reset_config``, ``_stop``, and ``_export_model``. When using Tune, Tune will convert this class into a Ray actor, which runs on a separate process. Tune will also change the current working directory of this process to `self.logdir`. """ def __init__(self, config=None, logger_creator=None): """Initialize an Trainable. Sets up logging and points ``self.logdir`` to a directory in which training outputs should be placed. Subclasses should prefer defining ``_setup()`` instead of overriding ``__init__()`` directly. Args: config (dict): Trainable-specific configuration data. By default will be saved as ``self.config``. logger_creator (func): Function that creates a ray.tune.Logger object. If unspecified, a default logger is created. """ self._experiment_id = uuid.uuid4().hex self.config = config or {} if logger_creator: self._result_logger = logger_creator(self.config) self._logdir = self._result_logger.logdir else: logdir_prefix = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") if not os.path.exists(DEFAULT_RESULTS_DIR): os.makedirs(DEFAULT_RESULTS_DIR) self._logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR) self._result_logger = UnifiedLogger(self.config, self._logdir, loggers=None) self._iteration = 0 self._time_total = 0.0 self._timesteps_total = None self._episodes_total = None self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = False start_time = time.time() self._setup(copy.deepcopy(self.config)) setup_time = time.time() - start_time if setup_time > SETUP_TIME_THRESHOLD: logger.info("_setup took {:.3f} seconds. If your trainable is " "slow to initialize, consider setting " "reuse_actors=True to reduce actor creation " "overheads.".format(setup_time)) self._local_ip = ray.services.get_node_ip_address() log_sys_usage = self.config.get("log_sys_usage", False) self._monitor = UtilMonitor(start=log_sys_usage) @classmethod def default_resource_request(cls, config): """Returns the resource requirement for the given configuration. This can be overriden by sub-classes to set the correct trial resource allocation, so the user does not need to. Example: >>> def default_resource_request(cls, config): >>> return Resources( >>> cpu=0, >>> gpu=0, >>> extra_cpu=config["workers"], >>> extra_gpu=int(config["use_gpu"]) * config["workers"]) """ return None @classmethod def resource_help(cls, config): """Returns a help string for configuring this trainable's resources.""" return "" def current_ip(self): logger.warning("Getting current IP.") self._local_ip = ray.services.get_node_ip_address() return self._local_ip def train(self): """Runs one logical iteration of training. Subclasses should override ``_train()`` instead to return results. 
This class automatically fills the following fields in the result: `done` (bool): training is terminated. Filled only if not provided. `time_this_iter_s` (float): Time in seconds this iteration took to run. This may be overriden in order to override the system-computed time difference. `time_total_s` (float): Accumulated time in seconds for this entire experiment. `experiment_id` (str): Unique string identifier for this experiment. This id is preserved across checkpoint / restore calls. `training_iteration` (int): The index of this training iteration, e.g. call to train(). This is incremented after `_train()` is called. `pid` (str): The pid of the training process. `date` (str): A formatted date of when the result was processed. `timestamp` (str): A UNIX timestamp of when the result was processed. `hostname` (str): Hostname of the machine hosting the training process. `node_ip` (str): Node ip of the machine hosting the training process. Returns: A dict that describes training progress. """ start = time.time() result = self._train() assert isinstance(result, dict), "_train() needs to return a dict." # We do not modify internal state nor update this result if duplicate. if RESULT_DUPLICATE in result: return result result = result.copy() self._iteration += 1 self._iterations_since_restore += 1 if result.get(TIME_THIS_ITER_S) is not None: time_this_iter = result[TIME_THIS_ITER_S] else: time_this_iter = time.time() - start self._time_total += time_this_iter self._time_since_restore += time_this_iter result.setdefault(DONE, False) # self._timesteps_total should only be tracked if increments provided if result.get(TIMESTEPS_THIS_ITER) is not None: if self._timesteps_total is None: self._timesteps_total = 0 self._timesteps_total += result[TIMESTEPS_THIS_ITER] self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] # self._episodes_total should only be tracked if increments provided if result.get(EPISODES_THIS_ITER) is not None: if self._episodes_total is None: self._episodes_total = 0 self._episodes_total += result[EPISODES_THIS_ITER] # self._timesteps_total should not override user-provided total result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total) result.setdefault(EPISODES_TOTAL, self._episodes_total) result.setdefault(TRAINING_ITERATION, self._iteration) # Provides auto-filled neg_mean_loss for avoiding regressions if result.get("mean_loss"): result.setdefault("neg_mean_loss", -result["mean_loss"]) now = datetime.today() result.update(experiment_id=self._experiment_id, date=now.strftime("%Y-%m-%d_%H-%M-%S"), timestamp=int(time.mktime(now.timetuple())), time_this_iter_s=time_this_iter, time_total_s=self._time_total, pid=os.getpid(), hostname=os.uname()[1], node_ip=self._local_ip, config=self.config, time_since_restore=self._time_since_restore, timesteps_since_restore=self._timesteps_since_restore, iterations_since_restore=self._iterations_since_restore) monitor_data = self._monitor.get_data() if monitor_data: result.update(monitor_data) self._log_result(result) return result def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. Subclasses should override ``_save()`` instead to save state. This method dumps additional metadata alongside the saved path. Args: checkpoint_dir (str): Optional dir to place the checkpoint. Returns: Checkpoint path or prefix that may be passed to restore(). 
""" checkpoint_dir = os.path.join(checkpoint_dir or self.logdir, "checkpoint_{}".format(self._iteration)) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) checkpoint = self._save(checkpoint_dir) saved_as_dict = False if isinstance(checkpoint, string_types): if not checkpoint.startswith(checkpoint_dir): raise ValueError( "The returned checkpoint path must be within the " "given checkpoint dir {}: {}".format( checkpoint_dir, checkpoint)) checkpoint_path = checkpoint elif isinstance(checkpoint, dict): saved_as_dict = True checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") with open(checkpoint_path, "wb") as f: pickle.dump(checkpoint, f) else: raise ValueError("Returned unexpected type {}. " "Expected str or dict.".format(type(checkpoint))) with open(checkpoint_path + ".tune_metadata", "wb") as f: pickle.dump( { "experiment_id": self._experiment_id, "iteration": self._iteration, "timesteps_total": self._timesteps_total, "time_total": self._time_total, "episodes_total": self._episodes_total, "saved_as_dict": saved_as_dict, "ray_version": ray.__version__, }, f) return checkpoint_path def save_to_object(self): """Saves the current model state to a Python object. It also saves to disk but does not return the checkpoint path. Returns: Object holding checkpoint data. """ tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) checkpoint_path = self.save(tmpdir) # Save all files in subtree. data = {} for basedir, _, file_names in os.walk(tmpdir): for file_name in file_names: path = os.path.join(basedir, file_name) with open(path, "rb") as f: data[os.path.relpath(path, tmpdir)] = f.read() out = io.BytesIO() data_dict = pickle.dumps({ "checkpoint_name": os.path.relpath(checkpoint_path, tmpdir), "data": data, }) if len(data_dict) > 10e6: # getting pretty large logger.info("Checkpoint size is {} bytes".format(len(data_dict))) out.write(data_dict) shutil.rmtree(tmpdir) return out.getvalue() def restore(self, checkpoint_path): """Restores training state from a given model checkpoint. These checkpoints are returned from calls to save(). Subclasses should override ``_restore()`` instead to restore state. This method restores additional metadata saved with the checkpoint. """ with open(checkpoint_path + ".tune_metadata", "rb") as f: metadata = pickle.load(f) self._experiment_id = metadata["experiment_id"] self._iteration = metadata["iteration"] self._timesteps_total = metadata["timesteps_total"] self._time_total = metadata["time_total"] self._episodes_total = metadata["episodes_total"] saved_as_dict = metadata["saved_as_dict"] if saved_as_dict: with open(checkpoint_path, "rb") as loaded_state: checkpoint_dict = pickle.load(loaded_state) checkpoint_dict.update(tune_checkpoint_path=checkpoint_path) self._restore(checkpoint_dict) else: self._restore(checkpoint_path) self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 self._restored = True logger.info("Restored from checkpoint: %s", checkpoint_path) state = { "_iteration": self._iteration, "_timesteps_total": self._timesteps_total, "_time_total": self._time_total, "_episodes_total": self._episodes_total, } logger.info("Current state after restoring: {}".format(state)) def restore_from_object(self, obj): """Restores training state from a checkpoint object. These checkpoints are returned from calls to save_to_object(). 
""" info = pickle.loads(obj) data = info["data"] tmpdir = tempfile.mkdtemp("restore_from_object", dir=self.logdir) checkpoint_path = os.path.join(tmpdir, info["checkpoint_name"]) for relpath_name, file_contents in data.items(): path = os.path.join(tmpdir, relpath_name) # This may be a subdirectory, hence not just using tmpdir if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) with open(path, "wb") as f: f.write(file_contents) self.restore(checkpoint_path) shutil.rmtree(tmpdir) def export_model(self, export_formats, export_dir=None): """Exports model based on export_formats. Subclasses should override _export_model() to actually export model to local directory. Args: export_formats (list): List of formats that should be exported. export_dir (str): Optional dir to place the exported model. Defaults to self.logdir. Returns: A dict that maps ExportFormats to successfully exported models. """ export_dir = export_dir or self.logdir return self._export_model(export_formats, export_dir) def reset_config(self, new_config): """Resets configuration without restarting the trial. This method is optional, but can be implemented to speed up algorithms such as PBT, and to allow performance optimizations such as running experiments with reuse_actors=True. Note that self.config need to be updated to reflect the latest parameter information in Ray logs. Args: new_config (dir): Updated hyperparameter configuration for the trainable. Returns: True if reset was successful else False. """ return False def stop(self): """Releases all resources used by this trainable.""" self._result_logger.flush() self._result_logger.close() self._stop() @property def logdir(self): """Directory of the results and checkpoints for this Trainable. Tune will automatically sync this folder with the driver if execution is distributed. Note that the current working directory will also be changed to this. """ return self._logdir @property def iteration(self): """Current training iteration. This value is automatically incremented every time `train()` is called and is automatically inserted into the training result dict. """ return self._iteration def get_config(self): """Returns configuration passed in by Tune.""" return self.config def _train(self): """Subclasses should override this to implement train(). The return value will be automatically passed to the loggers. Users can also return `tune.result.DONE` or `tune.result.SHOULD_CHECKPOINT` as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables. Returns: A dict that describes training progress. """ raise NotImplementedError def _save(self, tmp_checkpoint_dir): """Subclasses should override this to implement ``save()``. Warning: Do not rely on absolute paths in the implementation of ``_save`` and ``_restore``. Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution. >>> from ray.tune.util import validate_save_restore >>> validate_save_restore(MyTrainableClass) >>> validate_save_restore(MyTrainableClass, use_object_store=True) Args: tmp_checkpoint_dir (str): The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved. Returns: A dict or string. If string, the return value is expected to be prefixed by `tmp_checkpoint_dir`. If dict, the return value will be automatically serialized by Tune and passed to `_restore()`. 
Examples: >>> print(trainable1._save("/tmp/checkpoint_1")) "/tmp/checkpoint_1/my_checkpoint_file" >>> print(trainable2._save("/tmp/checkpoint_2")) {"some": "data"} >>> trainable._save("/tmp/bad_example") "/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error. """ raise NotImplementedError def _restore(self, checkpoint): """Subclasses should override this to implement restore(). Warning: In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in ``_save`` may be changed. If ``_save`` returned a prefixed string, the prefix of the checkpoint string returned by ``_save`` may be changed. This is because trial pausing depends on temporary directories. The directory structure under the checkpoint_dir provided to ``_save`` is preserved. See the example below. .. code-block:: python class Example(Trainable): def _save(self, checkpoint_path): print(checkpoint_path) return os.path.join(checkpoint_path, "my/check/point") def _restore(self, checkpoint): print(checkpoint) >>> trainer = Example() >>> obj = trainer.save_to_object() # This is used when PAUSED. <logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point >>> trainer.restore_from_object(obj) # Note the different prefix. <logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point Args: checkpoint (str|dict): If dict, the return value is as returned by `_save`. If a string, then it is a checkpoint path that may have a different prefix than that returned by `_save`. The directory structure underneath the `checkpoint_dir` `_save` is preserved. """ raise NotImplementedError def _setup(self, config): """Subclasses should override this for custom initialization. Args: config (dict): Hyperparameters and other configs given. Copy of `self.config`. """ pass def _log_result(self, result): """Subclasses can optionally override this to customize logging. Args: result (dict): Training result returned by _train(). """ self._result_logger.on_result(result) def _stop(self): """Subclasses should override this for any cleanup on stop. If any Ray actors are launched in the Trainable (i.e., with a RLlib trainer), be sure to kill the Ray actor process here. You can kill a Ray actor by calling `actor.__ray_terminate__.remote()` on the actor. """ pass def _export_model(self, export_formats, export_dir): """Subclasses should override this to export model. Args: export_formats (list): List of formats that should be exported. export_dir (str): Directory to place exported models. Return: A dict that maps ExportFormats to successfully exported models. """ return {}
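
# Illustrative sketch (not part of the original module): the same toy
# trainable as above, written against this legacy underscore API
# (``_setup``/``_train``/``_save``/``_restore``). Names are hypothetical.
class MyLegacyTrainable(Trainable):
    def _setup(self, config):
        self.lr = config.get("lr", 0.1)
        self.weight = 1.0

    def _train(self):
        self.weight -= self.lr * 2 * self.weight  # gradient step on w**2
        # "timesteps_this_iter" increments let train() accumulate
        # ``timesteps_total`` automatically.
        return {"mean_loss": self.weight ** 2, "timesteps_this_iter": 1}

    def _save(self, tmp_checkpoint_dir):
        # Dict checkpoints are pickled to ``<dir>/checkpoint`` by save().
        return {"weight": self.weight}

    def _restore(self, checkpoint):
        self.weight = checkpoint["weight"]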