def logger_context(log_dir, run_ID, name, log_params=None, snapshot_mode="none"):
    """Configure the global ``logger`` for one experiment run.

    This is a generator with a single ``yield``, meant to be driven as a
    context manager (presumably wrapped with ``contextlib.contextmanager``;
    confirm at the definition site). On entry it:

    * sets the snapshot mode and turns off tabular-only logging,
    * resolves the run directory via ``get_log_dir(log_dir, run_ID)`` and
      uses it as the snapshot directory,
    * registers ``debug.log`` (text) and ``progress.csv`` (tabular) outputs
      in that directory and pushes the prefix ``"{name}_{run_ID} "``,
    * writes ``log_params`` plus ``name`` and ``run_ID`` to ``params.json``.

    On exit, whether normal or because the code inside the context raised,
    the two outputs are removed and the prefix is popped.

    Args:
        log_dir: Base log directory handed to ``get_log_dir``.
        run_ID: Run identifier; used in the prefix and saved to params.json.
        name: Experiment name; used in the prefix and saved to params.json.
        log_params: Optional mapping of hyperparameters to save. It is copied,
            so the caller's mapping is left unmodified.
        snapshot_mode: Forwarded to ``logger.set_snapshot_mode``.
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    exp_dir = get_log_dir(log_dir, run_ID)
    tabular_log_file = osp.join(exp_dir, "progress.csv")
    text_log_file = osp.join(exp_dir, "debug.log")
    params_log_file = osp.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix(f"{name}_{run_ID} ")
    # Everything after registering the outputs runs under try/finally so the
    # global logger is restored even if writing params.json or the code run
    # inside the context raises; otherwise the outputs and prefix would leak
    # into any later run in the same process.
    try:
        # Copy so that adding name/run_ID does not mutate the caller's dict.
        params = dict(log_params) if log_params is not None else {}
        params["name"] = name
        params["run_ID"] = run_ID
        with open(params_log_file, "w") as f:
            json.dump(params, f)
        yield
    finally:
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def logger_context(log_dir, run_ID, name, log_params=None, snapshot_mode="none"):
    """Configure the global ``logger`` for one experiment run.

    This is a generator with a single ``yield``, meant to be driven as a
    context manager (presumably wrapped with ``contextlib.contextmanager``;
    confirm at the definition site).

    The run directory is ``log_dir/run_{run_ID}``. If its absolute form is not
    inside ``LOG_DIR``, a notice is printed and the directory is replaced by
    ``get_log_dir(log_dir/run_{run_ID})``. Per the notice, that is expected to
    place it under ``{LOG_DIR}/local/<yyyymmdd>/``; verify in ``get_log_dir``.

    On entry it sets the snapshot mode, turns off tabular-only logging, and
    uses the run directory as the snapshot directory. It registers
    ``debug.log`` (text) and ``progress.csv`` (tabular) outputs there, pushes
    the prefix ``"{name}_{run_ID} "``, and writes ``log_params`` plus
    ``name`` and ``run_ID`` to ``params.json``. On exit, whether normal or
    because the code inside the context raised, the outputs are removed and
    the prefix is popped.

    Args:
        log_dir: Base log directory; ``run_{run_ID}`` is appended to it.
        run_ID: Run identifier; used in the path and prefix, and saved.
        name: Experiment name; used in the prefix and saved to params.json.
        log_params: Optional mapping of hyperparameters to save. It is copied,
            so the caller's mapping is left unmodified.
        snapshot_mode: Forwarded to ``logger.set_snapshot_mode``.
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    # A per-run sub-directory keeps repeated runs of the same settings apart.
    log_dir = osp.join(log_dir, f"run_{run_ID}")
    exp_dir = osp.abspath(log_dir)
    if LOG_DIR != osp.commonpath([exp_dir, LOG_DIR]):
        # exp_dir is not under LOG_DIR: let get_log_dir pick the location.
        print(f"logger_context received log_dir outside of {LOG_DIR}: "
              f"prepending by {LOG_DIR}/local/<yyyymmdd>/")
        exp_dir = get_log_dir(log_dir)
    tabular_log_file = osp.join(exp_dir, "progress.csv")
    text_log_file = osp.join(exp_dir, "debug.log")
    params_log_file = osp.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix(f"{name}_{run_ID} ")
    # try/finally so the global logger is restored even if writing
    # params.json or the code run inside the context raises.
    try:
        # Copy so that adding name/run_ID does not mutate the caller's dict.
        params = dict(log_params) if log_params is not None else {}
        params["name"] = name
        params["run_ID"] = run_ID
        with open(params_log_file, "w") as f:
            json.dump(params, f)
        yield
    finally:
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
def logger_context(
    log_dir,
    run_ID,
    name,
    log_params=None,
    snapshot_mode="none",
    override_prefix=False,
    use_summary_writer=False,
):
    """Use as context manager around calls to the runner's ``train()`` method.

    Sets up the logger directory and filenames. ``/run_{run_ID}`` is always
    appended to ``log_dir`` to separate multiple runs of the same settings.
    If the resulting absolute path is not already inside the rlpyt logging
    directory ``LOG_DIR`` and ``override_prefix`` is False, a notice is
    printed and the directory is re-rooted via ``get_log_dir``. Per that
    notice, the new location is ``{LOG_DIR}/local/<yyyymmdd>/<hhmmss>/...``;
    see ``get_log_dir`` for the exact layout. With ``override_prefix=True``
    the absolute path is used as-is.

    Saves hyperparameters provided in ``log_params`` to `params.json`, along
    with experiment ``name`` and ``run_ID``. ``log_params`` is copied, so the
    caller's dict is not modified. Values that JSON cannot serialize are
    written as their type name instead of raising.

    The runner calls on the logger to save the snapshot at every iteration,
    but the input ``snapshot_mode`` sets how often the logger actually saves
    (e.g. the snapshot may include agent parameters). Possible modes include
    (but check inside the logger itself):

        * "none": don't save at all
        * "last": always save and overwrite the previous
        * "all": always save and keep each iteration
        * "gap": save periodically and keep each (will also need to set the
          gap, not done here)

    If ``use_summary_writer`` is True, a ``SummaryWriter`` writing to the run
    directory is handed to the logger and closed when the context exits.

    Cleanup on exit runs whether the wrapped code finishes normally or
    raises. It removes the text/tabular outputs, pops the prefix, and closes
    the summary writer if one was created. This matters when launching
    another training session in the same python process.
    """
    logger.set_snapshot_mode(snapshot_mode)
    logger.set_log_tabular_only(False)
    log_dir = osp.join(log_dir, f"run_{run_ID}")
    exp_dir = osp.abspath(log_dir)
    if LOG_DIR != osp.commonpath([exp_dir, LOG_DIR]) and not override_prefix:
        print(f"logger_context received log_dir outside of {LOG_DIR}: "
              f"prepending by {LOG_DIR}/local/<yyyymmdd>/<hhmmss>/")
        exp_dir = get_log_dir(log_dir)
    tabular_log_file = osp.join(exp_dir, "progress.csv")
    text_log_file = osp.join(exp_dir, "debug.log")
    params_log_file = osp.join(exp_dir, "params.json")

    logger.set_snapshot_dir(exp_dir)
    # Keep our own reference so the writer can be closed (flushing pending
    # events) on exit; the original code dropped it and never closed it.
    summary_writer = None
    if use_summary_writer:
        summary_writer = SummaryWriter(exp_dir)
        logger.set_tf_summary_writer(summary_writer)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.push_prefix(f"{name}_{run_ID} ")
    # try/finally so the global logger is restored even if writing
    # params.json or the code run inside the context raises.
    try:
        # Copy so that adding name/run_ID does not mutate the caller's dict.
        params = dict(log_params) if log_params is not None else {}
        params["name"] = name
        params["run_ID"] = run_ID
        with open(params_log_file, "w") as f:
            # Non-serializable values (e.g. classes, objects) are recorded by
            # type name rather than aborting the dump.
            json.dump(params, f, default=lambda o: type(o).__name__)
        yield
    finally:
        logger.remove_tabular_output(tabular_log_file)
        logger.remove_text_output(text_log_file)
        logger.pop_prefix()
        if summary_writer is not None:
            summary_writer.close()