def run_wandb(entity, project, run_id, run_cls: type = Training, checkpoint_path: str = None):
    """Run `run_cls` while streaming its config and per-episode stats to https://wandb.com.

    Args:
        entity: wandb entity (user or team) that owns the project.
        project: wandb project name.
        run_id: unique run identifier; together with `resume` it lets wandb
            continue a previous run after an interruption.
        run_cls: callable producing the run object (consumed by `iterate_episodes`).
        checkpoint_path: optional checkpoint file; if it already exists the
            wandb run is resumed rather than started fresh.
    """
    wandb_dir = mkdtemp()  # prevent wandb from polluting the home directory
    atexit.register(shutil.rmtree, wandb_dir, ignore_errors=True)  # clean up after wandb atexit handler finishes
    import wandb
    config = partial_to_dict(run_cls)
    config['seed'] = config['seed'] or randrange(1, 1000000)  # if seed == 0 replace with random
    config['environ'] = log_environment_variables()
    config['git'] = git_info()
    # resume only when a checkpoint file actually exists on disk
    resume = checkpoint_path and exists(checkpoint_path)
    wandb.init(dir=wandb_dir, entity=entity, project=project, id=run_id, resume=resume, config=config)
    for stats in iterate_episodes(run_cls, checkpoint_path):
        # plain loop, not a comprehension: wandb.log is called for its side effect
        for s in stats:
            wandb.log(json.loads(s.to_json()))
def run_fs(path: str, run_cls: type = Training):
    """Run `run_cls` and save config and stats to `path` (with pickle).

    Layout under `path`:
        spec.json — the run configuration (`partial_to_dict(run_cls)`)
        stats     — pickled pd.DataFrame, one row per episode, grown each episode
        state     — checkpoint file managed by `iterate_episodes`

    Args:
        path: directory to create (if missing) and write into.
        run_cls: callable producing the run object (consumed by `iterate_episodes`).
    """
    if not exists(path):
        os.mkdir(path)
    save_json(partial_to_dict(run_cls), path + '/spec.json')
    if not exists(path + '/stats'):
        dump(pd.DataFrame(), path + '/stats')
    for stats in iterate_episodes(run_cls, path + '/state'):
        # concat with stats from previous episodes; DataFrame.append was
        # removed in pandas 2.0, pd.concat performs the same row-wise extension
        # (stats is a list of pd.Series, one per episode — see iterate_episodes)
        previous = load(path + '/stats')
        dump(pd.concat([previous, pd.DataFrame(stats)], ignore_index=True), path + '/stats')
def iterate_episodes(run_cls: type = Training, checkpoint_path: str = None):
    """Generator [1] yielding episode statistics (list of pd.Series) while running and checkpointing.

    - run_cls: can be any callable that outputs an appropriate run object (e.g. has a 'run_epoch' method)
    - checkpoint_path: file used to persist the run between epochs; when None a
      temporary path is used and deleted on exit.

    [1] https://docs.python.org/3/howto/functional.html#generators
    """
    if not checkpoint_path:
        # tempfile.mktemp is deprecated and race-prone; a uuid-based path in the
        # temp dir is equally unique and equally guaranteed not to exist yet.
        # The "_remove_on_exit" suffix marks the checkpoint as temporary so the
        # finally block below deletes it.
        import uuid
        checkpoint_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + "_remove_on_exit")
    try:
        if not exists(checkpoint_path):
            # fresh run: print the spec and create the initial checkpoint
            print("=== specification ".ljust(70, "="))
            print(yaml.dump(partial_to_dict(run_cls), indent=3, default_flow_style=False, sort_keys=False), end="")
            run_instance = run_cls()
            dump(run_instance, checkpoint_path)
            print("")
        else:
            # checkpoint already on disk: resume from it
            print("\ncontinuing...\n")
            run_instance = load(checkpoint_path)
        while run_instance.epoch < run_instance.epochs:
            # time.sleep(1)  # on network file systems writing files is asynchronous and we need to wait for sync
            yield run_instance.run_epoch()  # yield stats data frame (this makes this function a generator)
            print("")
            dump(run_instance, checkpoint_path)
            # we delete and reload the run_instance from disk to ensure the exact same code runs regardless of interruptions
            del run_instance
            gc.collect()
            run_instance = load(checkpoint_path)
    finally:
        # only delete checkpoints we created ourselves (marked by the suffix)
        if checkpoint_path.endswith("_remove_on_exit") and exists(checkpoint_path):
            os.remove(checkpoint_path)