def testCustomLoggerWithAutoLogging(self):
    """Creates CSV/JSON logger callbacks automatically"""
    # Make sure automatic logger creation is enabled for this test.
    os.environ.pop("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", None)

    class CustomLogger(Logger):
        def on_result(self, result):
            log_path = os.path.join(self.logdir, "test.log")
            with open(log_path, "w") as f:
                f.write("hi")

    def run_single_trial(**run_kwargs):
        # A fresh experiment spec is built for every run.
        spec = {"run": "__fake", "stop": {"training_iteration": 1}}
        [single_trial] = run_experiments({"foo": spec}, **run_kwargs)
        return single_trial

    def assert_in_logdir(trial_obj, filename):
        self.assertTrue(
            os.path.exists(os.path.join(trial_obj.logdir, filename)))

    # A custom legacy logger: its own file and params.json must both exist.
    trial = run_single_trial(
        callbacks=[LegacyLoggerCallback(logger_classes=[CustomLogger])])
    assert_in_logdir(trial, "test.log")
    assert_in_logdir(trial, "params.json")

    # No callbacks at all: params.json must still be written.
    assert_in_logdir(run_single_trial(), "params.json")

    # A legacy logger callback without any logger classes.
    trial = run_single_trial(
        callbacks=[LegacyLoggerCallback(logger_classes=[])])
    assert_in_logdir(trial, "params.json")
def testCallbackReordering(self):
    """SyncerCallback should come after LoggerCallback callbacks"""

    def get_positions(callbacks):
        """Return (index of last LoggerCallback, index of last SyncerCallback).

        Either entry is ``None`` if no callback of that kind is present.
        """
        last_logger_pos = None
        syncer_pos = None
        for i, callback in enumerate(callbacks):
            if isinstance(callback, LoggerCallback):
                last_logger_pos = i
            elif isinstance(callback, SyncerCallback):
                syncer_pos = i
        return last_logger_pos, syncer_pos

    def assert_syncer_after_loggers(callbacks):
        """Assert the syncer follows every logger; return the syncer index."""
        last_logger_pos, syncer_pos = get_positions(callbacks)
        self.assertLess(last_logger_pos, syncer_pos)
        return syncer_pos

    # Auto creation of loggers, no callbacks, no syncer
    assert_syncer_after_loggers(
        create_default_callbacks(None, SyncConfig(), None))

    # Auto creation of loggers with callbacks
    assert_syncer_after_loggers(
        create_default_callbacks([Callback()], SyncConfig(), None))

    # Auto creation of loggers with existing logger (but no CSV/JSON)
    assert_syncer_after_loggers(
        create_default_callbacks([LoggerCallback()], SyncConfig(), None))

    # This should throw an error as the syncer comes before the logger
    with self.assertRaises(ValueError):
        create_default_callbacks(
            [SyncerCallback(None), LoggerCallback()], SyncConfig(), None)

    # This should be reordered but preserve the regular callback order
    mc1, mc2, mc3 = Callback(), Callback(), Callback()
    # Has to be legacy logger to avoid logger callback creation
    lc = LegacyLoggerCallback(logger_classes=DEFAULT_LOGGERS)
    callbacks = create_default_callbacks([mc1, mc2, lc, mc3], SyncConfig(),
                                         None)
    syncer_pos = assert_syncer_after_loggers(callbacks)
    self.assertLess(callbacks.index(mc1), callbacks.index(mc2))
    self.assertLess(callbacks.index(mc2), callbacks.index(mc3))
    self.assertLess(callbacks.index(lc), callbacks.index(mc3))
    # Syncer callback is appended
    self.assertLess(callbacks.index(mc3), syncer_pos)
def create_default_callbacks(callbacks: Optional[List[Callback]],
                             sync_config: SyncConfig,
                             loggers: Optional[List[Logger]],
                             metric: Optional[str] = None):
    """Create default callbacks for `tune.run()`.

    This function takes a list of existing callbacks and adds default
    callbacks to it. Specifically, three kinds of callbacks will be added:

    1. Loggers. Ray Tune's experiment analysis relies on CSV and JSON logging.
    2. Syncer. Ray Tune synchronizes logs and checkpoint between workers and
       the head node.
    3. Trial progress reporter. For reporting intermediate progress, like trial
       results, Ray Tune uses a callback.

    These callbacks will only be added if they don't already exist, i.e. if
    they haven't been passed (and configured) by the user. A notable case
    is when a Logger is passed, which is not a CSV or JSON logger - then
    a CSV and JSON logger will still be created.

    Lastly, this function will ensure that the Syncer callback comes after all
    Logger callbacks, to ensure that the most up-to-date logs and checkpoints
    are synced across nodes.

    Args:
        callbacks: Callbacks passed by the user. The list is copied; the
            caller's list object is not modified.
        sync_config: Used to configure the default `SyncerCallback` if none
            was passed.
        loggers: Legacy argument. `LoggerCallback` instances are appended
            as-is; `Logger` classes are wrapped in one `LegacyLoggerCallback`.
        metric: Passed to the default `TrialProgressCallback`.

    Returns:
        A new list with the user's callbacks followed by the default ones.

    Raises:
        ValueError: If an entry of `loggers` is neither a `LoggerCallback`
            nor a `Logger` subclass, or if a `SyncerCallback` in `callbacks`
            comes before a `LoggerCallback` that was also placed in
            `callbacks` by the user.
    """
    # Work on a copy: the appends and re-ordering below must not leak into
    # the list object the caller passed in.
    callbacks = list(callbacks) if callbacks else []

    # Position of the last LoggerCallback the user ordered themselves. Every
    # callback appended further down is positioned by this function, so only
    # an ordering problem involving these can be reported back to the user.
    last_user_logger_index = None
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LoggerCallback):
            last_user_logger_index = i

    has_syncer_callback = False
    has_csv_logger = False
    has_json_logger = False
    has_tbx_logger = False

    if not any(isinstance(c, TrialProgressCallback) for c in callbacks):
        callbacks.append(TrialProgressCallback(metric=metric))

    # Track syncer obj/index to move callback after loggers
    last_logger_index = None
    syncer_index = None

    # Create LegacyLoggerCallback for passed Logger classes
    if loggers:
        # Todo(krfricke): Deprecate `loggers` argument, print warning here.
        # Add warning as soon as we ported all loggers to LoggerCallback
        # classes.
        add_loggers = []
        for trial_logger in loggers:
            if isinstance(trial_logger, LoggerCallback):
                callbacks.append(trial_logger)
            elif isinstance(trial_logger, type) and issubclass(
                    trial_logger, Logger):
                add_loggers.append(trial_logger)
            else:
                raise ValueError(
                    f"Invalid value passed to `loggers` argument of "
                    f"`tune.run()`: {trial_logger}")
        if add_loggers:
            callbacks.append(LegacyLoggerCallback(add_loggers))

    # Check if we have a CSV, JSON and TensorboardX logger, and record the
    # positions of the last logger and of the (last) syncer.
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LegacyLoggerCallback):
            if CSVLogger in callback.logger_classes:
                has_csv_logger = True
            if JsonLogger in callback.logger_classes:
                has_json_logger = True
            if TBXLogger in callback.logger_classes:
                has_tbx_logger = True
        elif isinstance(callback, CSVLoggerCallback):
            has_csv_logger = True
        elif isinstance(callback, JsonLoggerCallback):
            has_json_logger = True
        elif isinstance(callback, TBXLoggerCallback):
            has_tbx_logger = True
        elif isinstance(callback, SyncerCallback):
            syncer_index = i
            has_syncer_callback = True

        # Every logger, including custom LoggerCallback subclasses, has to
        # write its files before the syncer runs.
        if isinstance(callback, LoggerCallback):
            last_logger_index = i

    # If CSV, JSON or TensorboardX loggers are missing, add
    if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
        if not has_csv_logger:
            callbacks.append(CSVLoggerCallback())
            last_logger_index = len(callbacks) - 1
        if not has_json_logger:
            callbacks.append(JsonLoggerCallback())
            last_logger_index = len(callbacks) - 1
        if not has_tbx_logger:
            try:
                callbacks.append(TBXLoggerCallback())
                last_logger_index = len(callbacks) - 1
            except ImportError:
                logger.warning(
                    "The TensorboardX logger cannot be instantiated because "
                    "either TensorboardX or one of it's dependencies is not "
                    "installed. Please make sure you have the latest version "
                    "of TensorboardX installed: `pip install -U tensorboardx`")

    # If no SyncerCallback was found, add
    if not has_syncer_callback and os.environ.get(
            "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
        # Detect Docker and Kubernetes environments
        _sync_to_driver = detect_sync_to_driver(sync_config.sync_to_driver)
        syncer_callback = SyncerCallback(sync_function=_sync_to_driver)
        callbacks.append(syncer_callback)
        syncer_index = len(callbacks) - 1

    if (syncer_index is not None and last_logger_index is not None
            and syncer_index < last_logger_index):
        if (last_user_logger_index is not None
                and syncer_index < last_user_logger_index):
            # The user put one of their own loggers after their own syncer;
            # only they can fix that order.
            raise ValueError(
                "The `SyncerCallback` you passed to `tune.run()` came before "
                "at least one `LoggerCallback`. Syncing should be done "
                "after writing logs. Please re-order the callbacks so that "
                "the `SyncerCallback` comes after any `LoggerCallback`.")
        # The loggers behind the syncer were placed by this function, so
        # just move the syncer. Popping shifts the last logger one slot to
        # the left, so inserting at its old index lands right after it.
        syncer_obj = callbacks.pop(syncer_index)
        callbacks.insert(last_logger_index, syncer_obj)

    return callbacks
def create_default_callbacks(callbacks: Optional[List[Callback]],
                             sync_config: SyncConfig,
                             loggers: Optional[List[Logger]]):
    """Create default logger and syncer callbacks for `tune.run()`.

    CSV, JSON and TensorboardX logger callbacks and a `SyncerCallback` are
    appended unless equivalents were already passed, or unless disabled via
    the `TUNE_DISABLE_AUTO_CALLBACK_LOGGERS` /
    `TUNE_DISABLE_AUTO_CALLBACK_SYNCER` environment variables. The syncer is
    then made to come after every logger callback, so that the most
    up-to-date logs are synced.

    Args:
        callbacks: Callbacks passed by the user. The list is copied; the
            caller's list object is not modified.
        sync_config: Used to configure the default `SyncerCallback` if none
            was passed.
        loggers: Legacy argument. `LoggerCallback` instances are appended
            as-is; `Logger` classes are wrapped in one `LegacyLoggerCallback`.

    Returns:
        A new list with the user's callbacks followed by the default ones.

    Raises:
        ValueError: If an entry of `loggers` is neither a `LoggerCallback`
            nor a `Logger` subclass, or if a `SyncerCallback` in `callbacks`
            comes before a `LoggerCallback` that was also placed in
            `callbacks` by the user.
    """
    import logging

    # Work on a copy: the appends and re-ordering below must not leak into
    # the list object the caller passed in.
    callbacks = list(callbacks) if callbacks else []

    # Position of the last LoggerCallback the user ordered themselves. Every
    # callback appended further down is positioned by this function, so only
    # an ordering problem involving these can be reported back to the user.
    last_user_logger_index = None
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LoggerCallback):
            last_user_logger_index = i

    has_syncer_callback = False
    has_csv_logger = False
    has_json_logger = False
    has_tbx_logger = False

    # Track syncer obj/index to move callback after loggers
    last_logger_index = None
    syncer_index = None

    # Create LegacyLoggerCallback for passed Logger classes
    if loggers:
        # Todo(krfricke): Deprecate `loggers` argument, print warning here.
        # Add warning as soon as we ported all loggers to LoggerCallback
        # classes.
        add_loggers = []
        for trial_logger in loggers:
            if isinstance(trial_logger, LoggerCallback):
                callbacks.append(trial_logger)
            elif isinstance(trial_logger, type) and issubclass(
                    trial_logger, Logger):
                add_loggers.append(trial_logger)
            else:
                raise ValueError(
                    f"Invalid value passed to `loggers` argument of "
                    f"`tune.run()`: {trial_logger}")
        if add_loggers:
            callbacks.append(LegacyLoggerCallback(add_loggers))

    # Check if we have a CSV, JSON and TensorboardX logger, and record the
    # positions of the last logger and of the (last) syncer.
    for i, callback in enumerate(callbacks):
        if isinstance(callback, LegacyLoggerCallback):
            if CSVLogger in callback.logger_classes:
                has_csv_logger = True
            if JsonLogger in callback.logger_classes:
                has_json_logger = True
            if TBXLogger in callback.logger_classes:
                has_tbx_logger = True
        elif isinstance(callback, CSVLoggerCallback):
            has_csv_logger = True
        elif isinstance(callback, JsonLoggerCallback):
            has_json_logger = True
        elif isinstance(callback, TBXLoggerCallback):
            has_tbx_logger = True
        elif isinstance(callback, SyncerCallback):
            syncer_index = i
            has_syncer_callback = True

        # Every logger, including custom LoggerCallback subclasses, has to
        # write its files before the syncer runs.
        if isinstance(callback, LoggerCallback):
            last_logger_index = i

    # If CSV, JSON or TensorboardX loggers are missing, add
    if os.environ.get("TUNE_DISABLE_AUTO_CALLBACK_LOGGERS", "0") != "1":
        if not has_csv_logger:
            callbacks.append(CSVLoggerCallback())
            last_logger_index = len(callbacks) - 1
        if not has_json_logger:
            callbacks.append(JsonLoggerCallback())
            last_logger_index = len(callbacks) - 1
        if not has_tbx_logger:
            try:
                callbacks.append(TBXLoggerCallback())
                last_logger_index = len(callbacks) - 1
            except ImportError:
                # A missing optional TensorboardX install should not make
                # `tune.run()` fail; skip that logger with a warning.
                logging.getLogger(__name__).warning(
                    "The TensorboardX logger cannot be instantiated because "
                    "either TensorboardX or one of it's dependencies is not "
                    "installed. Please make sure you have the latest version "
                    "of TensorboardX installed: `pip install -U tensorboardx`")

    # If no SyncerCallback was found, add
    if not has_syncer_callback and os.environ.get(
            "TUNE_DISABLE_AUTO_CALLBACK_SYNCER", "0") != "1":
        # Detect Docker and Kubernetes environments
        _sync_to_driver = detect_sync_to_driver(sync_config.sync_to_driver)
        syncer_callback = SyncerCallback(sync_function=_sync_to_driver)
        callbacks.append(syncer_callback)
        syncer_index = len(callbacks) - 1

    if (syncer_index is not None and last_logger_index is not None
            and syncer_index < last_logger_index):
        if (last_user_logger_index is not None
                and syncer_index < last_user_logger_index):
            # The user put one of their own loggers after their own syncer;
            # only they can fix that order.
            raise ValueError(
                "The `SyncerCallback` you passed to `tune.run()` came before "
                "at least one `LoggerCallback`. Syncing should be done "
                "after writing logs. Please re-order the callbacks so that "
                "the `SyncerCallback` comes after any `LoggerCallback`.")
        # The loggers behind the syncer were placed by this function, so
        # just move the syncer. Popping shifts the last logger one slot to
        # the left, so inserting at its old index lands right after it.
        syncer_obj = callbacks.pop(syncer_index)
        callbacks.insert(last_logger_index, syncer_obj)

    return callbacks