def setup_logger(config: dict, directory, tag: str):
    """Configure the root logger with a stdout handler and a file handler.

    Args:
        config: Settings mapping. ``config['all_ranks']`` is consulted only
            on non-main processes; ``config['log_level']`` is read whenever
            this process is allowed to log.
        directory: Directory in which the log file is placed.
        tag: Base name of the log file; the file is ``<tag>.log``.
    """
    root = logging.getLogger()

    if mpi.is_main_proc() or config['all_ranks']:
        root.setLevel(config['log_level'])
    else:
        # A threshold above CRITICAL filters out every record on this rank.
        root.setLevel(logging.CRITICAL + 1)

    # Only label records with the rank when more than one process is running.
    rank_prefix = f'rank[{mpi.get_proc_rank()}] ' if mpi.get_num_procs() > 1 else ''
    record_format = '%(asctime)s.%(msecs)03d ' + rank_prefix + '%(levelname)s: %(message)s'
    shared_formatter = logging.Formatter(record_format, datefmt='%Y-%m-%d %H:%M:%S')

    def _attach(handler):
        # Both handlers share one formatter so console and file output match.
        handler.setFormatter(shared_formatter)
        root.addHandler(handler)

    _attach(logging.StreamHandler(stream=sys.stdout))

    # The handlers are attached on every rank, including silenced ones.
    log_path = os.path.join(directory, tag + '.log')
    _attach(mpi.MPIFileHandler(log_path))
def save(self, obj: object, name: str):
    """Append ``obj`` as a single JSON line to the info file for ``name``.

    Does nothing unless called on the main process. The target file is
    ``<directory>/<tag>_<name><suffix>`` and is opened in append mode.

    Args:
        obj: Value to serialise with ``json.dumps``.
        name: Identifier inserted into the file name.
    """
    if mpi.is_main_proc():
        target = os.path.join(self.directory, ''.join((self.tag, '_', name, self._suffix)))
        message = f'Saving info: {target}'
        logging.debug(message)

        with open(target, mode='a') as stream:
            # One record per line; serialise after opening, as before.
            stream.write(json.dumps(obj) + '\n')
def save(self, obj: object, num_steps: int, info: str):
    """Pickle ``obj`` to a file named after the step count and process rank.

    Runs on every rank when ``self.all_ranks`` is truthy, otherwise only on
    the main process. The file is
    ``<directory>/<tag>_steps-<num_steps>_rank-<rank>_<info><suffix>`` and is
    overwritten if it already exists.

    Args:
        obj: Object to serialise with ``pickle.dump``.
        num_steps: Step count embedded in the file name.
        info: Extra label appended to the file name.
    """
    if not (self.all_ranks or mpi.is_main_proc()):
        return

    step_and_rank = f'steps-{num_steps}_rank-{mpi.get_proc_rank()}'
    file_name = '_'.join((self.tag, step_and_rank, info)) + self._suffix
    target = os.path.join(self.directory, file_name)

    message = f'Saving rollout: {target}'
    logging.debug(message)

    with open(target, mode='wb') as sink:
        pickle.dump(obj, sink)
def save_config(config: dict, directory: str, tag: str, verbose=True):
    """Write ``config`` as indented, key-sorted JSON to ``<directory>/<tag>.json``.

    Does nothing unless called on the main process.

    Args:
        config: Mapping to serialise.
        directory: Directory in which the JSON file is written.
        tag: Base name of the output file.
        verbose: When truthy, also log the serialised text at INFO level.
    """
    if mpi.is_main_proc():
        text = json.dumps(config, indent=4, sort_keys=True)

        if verbose:
            logging.info(text)

        destination = os.path.join(directory, tag + '.json')
        with open(file=destination, mode='w') as out:
            out.write(text)
def save(self, module: AbstractActorCritic, num_steps: int):
    """Store ``module`` with ``torch.save`` and record ``num_steps`` beside it.

    Does nothing unless called on the main process. The destinations come
    from ``self._get_model_path()`` and ``self._get_info_path()``.

    Args:
        module: Model object handed to ``torch.save``.
        num_steps: Step count written as text to the info file.
    """
    if mpi.is_main_proc():
        # Model first, so the step count is only written after it is saved.
        checkpoint_path = self._get_model_path()
        message = f'Saving model: {checkpoint_path}'
        logging.debug(message)
        torch.save(obj=module, f=checkpoint_path)

        # The info file holds just the step count and is overwritten each time.
        steps_path = self._get_info_path()
        with open(steps_path, mode='w') as steps_file:
            steps_file.write(str(num_steps))