示例#1
0
文件: __init__.py 项目: wsg1873/rtrl
def run_wandb(entity,
              project,
              run_id,
              run_cls: type = Training,
              checkpoint_path: str = None):
    """Run training and save config and stats to https://wandb.com.

    Args:
        entity: wandb entity that owns the run.
        project: wandb project name.
        run_id: wandb run id (passed as ``id`` to ``wandb.init``).
        run_cls: callable producing the run object; forwarded to
            ``iterate_episodes``.
        checkpoint_path: optional checkpoint file. If it already exists, the
            wandb run is started with ``resume=True``.
    """
    wandb_dir = mkdtemp()  # prevent wandb from polluting the home directory
    # atexit runs handlers in LIFO order, so registering this before wandb is
    # imported/initialized is intended to run the cleanup only after wandb's
    # own atexit handler finishes.
    atexit.register(shutil.rmtree, wandb_dir, ignore_errors=True)
    import wandb
    config = partial_to_dict(run_cls)
    # a falsy seed (e.g. 0) is replaced with a random one
    config['seed'] = config['seed'] or randrange(1, 1000000)
    config['environ'] = log_environment_variables()
    config['git'] = git_info()
    # coerce to a real bool: `checkpoint_path and ...` alone would leak
    # None or "" into wandb's `resume` argument
    resume = bool(checkpoint_path) and exists(checkpoint_path)
    wandb.init(dir=wandb_dir,
               entity=entity,
               project=project,
               id=run_id,
               resume=resume,
               config=config)
    for stats in iterate_episodes(run_cls, checkpoint_path):
        # round-trip each stats entry through JSON so wandb receives plain
        # Python values
        for s in stats:
            wandb.log(json.loads(s.to_json()))
示例#2
0
文件: __init__.py 项目: wsg1873/rtrl
def run_fs(path: str, run_cls: type = Training):
    """Run training and save config and stats to `path` (with pickle).

    Files written inside `path`:
        spec.json -- the run specification, ``partial_to_dict(run_cls)``
        stats     -- DataFrame of all episode stats so far, rewritten after
                     every epoch
        state     -- checkpoint handed to ``iterate_episodes`` so an
                     interrupted run can be resumed
    """
    os.makedirs(path, exist_ok=True)
    save_json(partial_to_dict(run_cls), path + '/spec.json')
    if not exists(path + '/stats'):
        dump(pd.DataFrame(), path + '/stats')
    for stats in iterate_episodes(run_cls, path + '/state'):
        # concat with stats from previous episodes
        dump(_append_rows(load(path + '/stats'), stats), path + '/stats')


def _append_rows(frame, rows):
    """Return a new DataFrame with each item of `rows` appended as a row.

    Replacement for ``frame.append(rows, ignore_index=True)``; the
    ``DataFrame.append`` method was removed in pandas 2.0. Empty frames are
    skipped before concatenating because newer pandas deprecates letting
    empty entries influence the result dtypes.
    """
    new_rows = pd.DataFrame(list(rows))
    frames = [f for f in (frame, new_rows) if not f.empty]
    if not frames:
        return frame
    return pd.concat(frames, ignore_index=True)
示例#3
0
文件: __init__.py 项目: wsg1873/rtrl
def iterate_episodes(run_cls: type = Training, checkpoint_path: str = None):
    """Generator [1] yielding episode statistics (list of pd.Series) while running and checkpointing.

    - run_cls: can be any callable that outputs an appropriate run object
      (e.g. has a 'run_epoch' method)
    - checkpoint_path: file the run object is dumped to after every epoch.
      If it already exists, the run continues from it. If not given, a
      checkpoint inside a fresh temporary directory is used, and that
      directory is removed when the generator finishes, fails or is closed.

    [1] https://docs.python.org/3/howto/functional.html#generators
    """
    import shutil  # only needed to clean up an auto-created checkpoint dir

    temp_dir = None
    if not checkpoint_path:
        # mkdtemp creates the directory atomically, avoiding the race of the
        # deprecated tempfile.mktemp; remembering `temp_dir` (rather than
        # matching a name suffix) guarantees a user-supplied checkpoint is
        # never deleted.
        temp_dir = tempfile.mkdtemp()
        checkpoint_path = os.path.join(temp_dir, "checkpoint_remove_on_exit")

    try:
        if not exists(checkpoint_path):
            # fresh run: show the specification, build and persist the run
            print("=== specification ".ljust(70, "="))
            print(yaml.dump(partial_to_dict(run_cls),
                            indent=3,
                            default_flow_style=False,
                            sort_keys=False),
                  end="")
            run_instance = run_cls()
            dump(run_instance, checkpoint_path)
            print("")
        else:
            print("\ncontinuing...\n")

        # always continue from the on-disk state, so fresh and resumed runs
        # take the same path
        run_instance = load(checkpoint_path)
        while run_instance.epoch < run_instance.epochs:
            # NOTE: on network file systems writing files is asynchronous; if
            # reloads see stale data, a short time.sleep(1) here can wait for sync
            yield run_instance.run_epoch()  # this makes the function a generator
            print("")
            dump(run_instance, checkpoint_path)

            # we delete and reload the run_instance from disk to ensure the
            # exact same code runs regardless of interruptions
            del run_instance
            gc.collect()
            run_instance = load(checkpoint_path)

    finally:
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)