Example #1
0
def setup_logger(config: dict, directory, tag: str):
    """Configure the root logger to write to stdout and to a log file.

    The root logger's level comes from ``config['log_level']``, except on a
    process for which ``mpi.is_main_proc()`` is false while
    ``config['all_ranks']`` is falsy: there the level is raised above
    ``logging.CRITICAL`` to silence it. When ``mpi.get_num_procs()`` reports
    more than one process, each record is prefixed with ``rank[<rank>]``.

    Two handlers sharing one formatter are added to the root logger: a
    ``logging.StreamHandler`` on ``sys.stdout`` and an ``mpi.MPIFileHandler``
    constructed with the path ``<directory>/<tag>.log``. Handlers are only
    added, never removed, so each call adds another pair.

    Args:
        config: Must provide ``'log_level'``; ``'all_ranks'`` is read whenever
            the current process is not the main one.
        directory: Directory in which the log file is placed.
        tag: Base name of the log file (``.log`` is appended).
    """
    root = logging.getLogger()

    # A level above logging.CRITICAL silences the logger on muted ranks.
    muted = not (mpi.is_main_proc() or config['all_ranks'])
    root.setLevel(logging.CRITICAL + 1 if muted else config['log_level'])

    # Tag records with the rank only when several processes are running.
    rank_prefix = f'rank[{mpi.get_proc_rank()}] ' if mpi.get_num_procs() > 1 else ''

    record_format = ''.join(('%(asctime)s.%(msecs)03d ', rank_prefix, '%(levelname)s: %(message)s'))
    shared_formatter = logging.Formatter(record_format, datefmt='%Y-%m-%d %H:%M:%S')

    def attach(handler):
        # Give every handler the same layout before registering it.
        handler.setFormatter(shared_formatter)
        root.addHandler(handler)

    # The console handler is registered before the file handler is created.
    attach(logging.StreamHandler(stream=sys.stdout))
    attach(mpi.MPIFileHandler(os.path.join(directory, tag + '.log')))
Example #2
0
    def save(self, obj: object, name: str):
        """Append ``obj`` as one JSON line to ``<tag>_<name><suffix>``.

        The file is located in ``self.directory`` and opened in append mode,
        so repeated calls accumulate one JSON document per line. Only the
        process for which ``mpi.is_main_proc()`` is true writes anything;
        every other process returns immediately.

        Args:
            obj: Value serialised with ``json.dumps``.
            name: Middle part of the file name, placed between the tag and
                the suffix.
        """
        if mpi.is_main_proc():
            target = os.path.join(self.directory, ''.join((self.tag, '_', name, self._suffix)))
            logging.debug(f'Saving info: {target}')
            with open(target, mode='a') as stream:
                # Serialise inside the context so the file is opened first.
                stream.write(json.dumps(obj) + '\n')
Example #3
0
    def save(self, obj: object, num_steps: int, info: str):
        """Pickle ``obj`` into a file named after the step count and rank.

        The file is written to ``self.directory`` as
        ``<tag>_steps-<num_steps>_rank-<rank>_<info><suffix>`` and is opened
        in binary write mode, so an existing file of that name is replaced.
        When ``self.all_ranks`` is falsy, only the process for which
        ``mpi.is_main_proc()`` is true writes; otherwise every process does.

        Args:
            obj: Object handed to ``pickle.dump``.
            num_steps: Step count embedded in the file name.
            info: Free-form label embedded in the file name.
        """
        if not (self.all_ranks or mpi.is_main_proc()):
            return

        step_and_rank = f'steps-{num_steps}_rank-{mpi.get_proc_rank()}'
        filename = '_'.join((self.tag, step_and_rank, info)) + self._suffix
        target = os.path.join(self.directory, filename)

        logging.debug(f'Saving rollout: {target}')
        with open(target, mode='wb') as stream:
            pickle.dump(obj, stream)
Example #4
0
def save_config(config: dict, directory: str, tag: str, verbose=True):
    """Write ``config`` as pretty-printed JSON to ``<directory>/<tag>.json``.

    The JSON text uses a four-space indent and sorted keys. Only the process
    for which ``mpi.is_main_proc()`` is true does anything; every other
    process returns immediately. An existing file is overwritten.

    Args:
        config: Mapping to serialise with ``json.dumps``.
        directory: Directory that receives the JSON file.
        tag: Base name of the file (``.json`` is appended).
        verbose: When truthy, the JSON text is also logged at INFO level.
    """
    if mpi.is_main_proc():
        # Serialise once; the same text is logged and written to disk.
        text = json.dumps(config, indent=4, sort_keys=True)
        if verbose:
            logging.info(text)

        target = os.path.join(directory, tag + '.json')
        with open(target, 'w') as stream:
            stream.write(text)
Example #5
0
    def save(self, module: AbstractActorCritic, num_steps: int):
        """Store ``module`` with ``torch.save`` and record ``num_steps``.

        The model goes to the path returned by ``self._get_model_path()``.
        The step count is then written as text to the path returned by
        ``self._get_info_path()``, replacing any previous content. Only the
        process for which ``mpi.is_main_proc()`` is true writes anything.

        Args:
            module: Model object passed to ``torch.save``.
            num_steps: Step count written to the info file via ``str()``.
        """
        if mpi.is_main_proc():
            # Model first, then the step counter.
            checkpoint_file = self._get_model_path()
            logging.debug(f'Saving model: {checkpoint_file}')
            torch.save(module, checkpoint_file)

            steps_file = self._get_info_path()
            with open(steps_file, mode='w') as stream:
                stream.write(str(num_steps))