def __init__(self, output_dir=None, output_fname='progress.txt', exp_name=None):
    """
    Initialize a Logger.

    Args:
        output_dir (string): A directory for saving results to. If
            ``None``, defaults to a temp directory of the form
            ``/tmp/experiments/somerandomnumber``.

        output_fname (string): Name for the tab-separated-value file
            containing metrics logged throughout a training run.
            Defaults to ``progress.txt``.

        exp_name (string): Experiment name. If you run multiple training
            runs and give them all the same ``exp_name``, the plotter
            will know to group them. (Use case: if you run the same
            hyperparameter configuration with multiple random seeds, you
            should give them all the same ``exp_name``.)
    """
    # Only the rank-0 process in an MPI job writes to disk; all other
    # processes get no-op output handles so logs aren't duplicated.
    if proc_id()==0:
        self.output_dir = output_dir or "/tmp/experiments/%i"%int(time.time())
        if osp.exists(self.output_dir):
            print("Warning: Log dir %s already exists! Storing info there anyway."%self.output_dir)
        else:
            os.makedirs(self.output_dir)
        self.output_file = open(osp.join(self.output_dir, output_fname), 'w')
        # Make sure the output file gets closed when the process exits.
        atexit.register(self.output_file.close)
        print(colorize("Logging data to %s"%self.output_file.name, "green", bold=True))
    else:
        self.output_dir = None
        self.output_file = None
    self.first_row = True
    self.log_headers = []
    self.log_current_row = {}
    self.exp_name = exp_name
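# A minimal usage sketch (not part of the class): constructing a Logger from
# a training script. The directory and experiment names below are hypothetical.
#
#     logger = Logger(output_dir='/tmp/experiments/demo', exp_name='demo')
#     # prints: Logging data to /tmp/experiments/demo/progress.txt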
def dump_tabular(self):
    """
    Write all of the diagnostics from the current iteration.

    Writes both to stdout, and to the output file.
    """
    if proc_id()==0:
        vals = []
        # Size the stdout table's key column to fit the longest header.
        key_lens = [len(key) for key in self.log_headers]
        max_key_len = max(15, max(key_lens))
        keystr = '%'+'%d'%max_key_len
        fmt = "| " + keystr + "s | %15s |"
        n_slashes = 22 + max_key_len
        print("-"*n_slashes)
        for key in self.log_headers:
            val = self.log_current_row.get(key, "")
            valstr = "%8.3g"%val if hasattr(val, "__float__") else val
            print(fmt%(key, valstr))
            vals.append(val)
        print("-"*n_slashes, flush=True)
        if self.output_file is not None:
            # Write the header line once, then append one tab-separated
            # row of values per call.
            if self.first_row:
                self.output_file.write("\t".join(self.log_headers)+"\n")
            self.output_file.write("\t".join(map(str, vals))+"\n")
            self.output_file.flush()
    self.log_current_row.clear()
    self.first_row = False
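# Sketch of a typical epoch loop around dump_tabular, assuming a companion
# log_tabular(key, val) method (not shown in this excerpt) records a value
# under a header for the current iteration. The metric variable is hypothetical.
#
#     for epoch in range(50):
#         logger.log_tabular('Epoch', epoch)
#         logger.log_tabular('AverageReturn', avg_ret)  # hypothetical metric
#         logger.dump_tabular()  # prints a table, appends a row to progress.txt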
def save_config(self, config):
    """
    Log an experiment configuration.

    Call this once at the top of your experiment, passing in all important
    config vars as a dict. This will serialize the config to JSON, while
    handling anything which can't be serialized in a graceful way (writing
    as informative a string as possible).

    Example use:

    .. code-block:: python

        logger = EpochLogger(**logger_kwargs)
        logger.save_config(locals())
    """
    config_json = convert_json(config)
    if self.exp_name is not None:
        config_json['exp_name'] = self.exp_name
    if proc_id()==0:
        output = json.dumps(config_json, separators=(',',':\t'), indent=4, sort_keys=True)
        print(colorize('Saving config:\n', color='cyan', bold=True))
        print(output)
        with open(osp.join(self.output_dir, "config.json"), 'w') as out:
            out.write(output)
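# Sketch: since convert_json falls back to informative strings for anything
# it can't serialize, passing locals() from the top of an experiment function
# is safe. The function and its arguments below are hypothetical.
#
#     def train(env_name='CartPole-v1', seed=0, logger_kwargs=dict()):
#         logger = EpochLogger(**logger_kwargs)
#         logger.save_config(locals())  # writes config.json to the log dir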
def _pytorch_simple_save(self, itr=None):
    """
    Saves the PyTorch model (or models).
    """
    if proc_id()==0:
        assert hasattr(self, "pytorch_saver_elements"), \
            "First have to setup saving with self.setup_pytorch_saver"
        fpath = "pyt_save"
        fpath = osp.join(self.output_dir, fpath)
        fname = "model" + ("%d"%itr if itr is not None else "") + ".pt"
        fname = osp.join(fpath, fname)
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # We are using a non-recommended way of saving PyTorch models,
            # by pickling whole objects (which are dependent on the exact
            # directory structure at the time of saving) as opposed to
            # just saving network weights. This works sufficiently well
            # for the purposes of Spinning Up, but you may want to do
            # something different for your personal PyTorch project.
            # We use a catch_warnings() context to avoid the warnings about
            # not being able to save the source code.
            torch.save(self.pytorch_saver_elements, fname)
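# Sketch of the intended call pattern, assuming setup_pytorch_saver (not shown
# in this excerpt) stashes its argument on self.pytorch_saver_elements:
#
#     logger.setup_pytorch_saver(model)       # model: a torch.nn.Module
#     ...
#     logger._pytorch_simple_save(itr=epoch)  # writes pyt_save/model<epoch>.pt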
def log(self, msg, color='green'):
    """Print a colorized message to stdout."""
    if proc_id()==0:
        print(colorize(msg, color, bold=True))
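# e.g., a one-off status message:
#
#     logger.log('Warning: early stopping triggered.', color='yellow')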