def dump_snapshot_on_epoch_after(trainer):
    """Save a trainer snapshot every ``save_interval`` epochs and register aliases.

    The snapshot is written to ``epoch_<N><ext>`` inside the snapshot directory.
    It is then passed to ``io.link`` together with alias paths:

    * ``last_epoch<ext>``, always;
    * ``best_loss<ext>``, when ``trainer.runtime['loss']`` beats the stored
      ``best_loss`` (the stored value is updated);
    * ``best_error<ext>``, likewise for ``error`` / ``best_error``.

    Epochs that are not a multiple of ``save_interval`` are skipped.
    """
    if trainer.epoch % save_interval != 0:
        return

    snapshot_dir = get_snapshot_dir()
    snapshot = trainer.dump_snapshot()
    epoch_path = osp.join(snapshot_dir, 'epoch_{}'.format(trainer.epoch) + __snapshot_ext__)
    io.mkdir(osp.dirname(epoch_path))
    io.dump(epoch_path, snapshot)

    aliases = [osp.join(snapshot_dir, 'last_epoch' + __snapshot_ext__)]

    # Track the running minimum of each metric; a missing metric (None) never
    # counts as an improvement.
    for metric in ('loss', 'error'):
        best_key = 'best_' + metric
        best_so_far = trainer.runtime.get(best_key, np.inf)
        current = trainer.runtime.get(metric, None)
        if current is not None and best_so_far > current:
            trainer.runtime[best_key] = current
            aliases.append(osp.join(snapshot_dir, best_key + __snapshot_ext__))

    io.link(epoch_path, *aliases)
    logger.info('Model at epoch {} dumped to {}.\n(Also alias: {}).'.format(
        trainer.epoch, epoch_path, ', '.join(aliases)))
def dump_pack(cfg, pack):
    """Pickle ``pack`` into ``replays/<cfg.name>-<YYYYmmdd-HHMMSS>.replay.pkl``.

    The object stored is ``pack.make_pickleable()``. The ``replays`` directory
    is created (via ``io.mkdir``) if needed. The written path is logged at
    critical level.
    """
    payload = pack.make_pickleable()
    replay_dir = 'replays'
    stamp = time.strftime('%Y%m%d-%H%M%S')
    filename = cfg.name + '-' + stamp + '.replay.pkl'

    io.mkdir(replay_dir)
    with open(osp.join(replay_dir, filename), 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
    logger.critical('replay written to replays/{}'.format(filename))
def __init__(self, name, dump_dir=None, force_dump=False, state_mode='DEFAULT'):
    """Create a wrapped gym environment.

    Args:
        name: environment id passed to ``gym.make``.
        dump_dir: if truthy, the directory is created and the env is wrapped
            in ``gym.wrappers.Monitor`` writing there.
        force_dump: forwarded as ``force=`` to ``gym.wrappers.Monitor``.
        state_mode: one of ``'DEFAULT'``, ``'RENDER'`` or ``'BOTH'``; stored
            on ``self._state_mode``.

    Raises:
        ValueError: if ``state_mode`` is not one of the accepted values.
    """
    # Validate before building the env or touching the filesystem, so a bad
    # argument leaves no side effects. An explicit raise (unlike ``assert``)
    # is not stripped under ``python -O``.
    if state_mode not in ('DEFAULT', 'RENDER', 'BOTH'):
        raise ValueError('Invalid state_mode: {!r}; expected one of DEFAULT, RENDER, BOTH.'.format(state_mode))

    super().__init__()

    # Environment construction is serialized through get_env_lock().
    with get_env_lock():
        self._gym = gym.make(name)

    if dump_dir:
        io.mkdir(dump_dir)
        self._gym = gym.wrappers.Monitor(self._gym, dump_dir, force=force_dump)

    self._state_mode = state_mode
def _process(self, pair):
    """Persist a trajectory pair under a fresh uuid-named directory.

    Both observation sequences are saved as ``1.gif`` / ``2.gif`` for display.
    A copy of the pair with observations set to ``None`` is then dumped as
    ``pair.pkl``.

    Args:
        pair: a ``TrajectoryPair``-like object exposing ``t1_state``,
            ``t1_observation``, ``t1_action`` and the matching ``t2_*`` fields.

    Returns:
        ``(uid, stripped_pair)``, where ``uid`` is a hex uuid4 string and
        ``stripped_pair`` is the pair with both observations replaced by None.

    Raises:
        ValueError: if ``self._tformat`` is not ``'GIF'``.
    """
    # Check the format before creating anything on disk, so an unsupported
    # format does not leave an empty directory behind.
    if self._tformat != 'GIF':
        raise ValueError('Unknown trajectory format: {}'.format(self._tformat))

    uid = uuid.uuid4().hex
    dirname = _compose_dir(uid)
    io.mkdir(dirname)

    # Dump the files used for displaying.
    _save_gif(pair.t1_observation, osp.join(dirname, '1.gif'))
    _save_gif(pair.t2_observation, osp.join(dirname, '2.gif'))

    # The observations were written out as GIFs above; drop them from the
    # pair that is pickled and returned.
    pair = TrajectoryPair(pair.t1_state, None, pair.t1_action, pair.t2_state, None, pair.t2_action)

    # Dump the raw (stripped) pair.
    io.dump(osp.join(dirname, 'pair.pkl'), pair)
    return uid, pair
def tensorboard_summary_enable(trainer, tb_path=tensorboard_path):
    """Attach a TensorBoard ``FileWriter`` (and optionally a web server) to ``trainer``.

    Args:
        trainer: the trainer; ``trainer.runtime`` is read and updated, and
            ``trainer._tensorboard_writer`` (plus, when the web server is
            enabled, ``trainer._tensorboard_webserver``) are set.
        tb_path: log directory; defaults to ``<dir.root>/tensorboard`` when None.

    An existing log directory is deleted first unless the run was restored
    from a snapshot (``'restore_snapshot' in trainer.runtime``).
    """
    if tb_path is None:
        tb_path = osp.join(get_env('dir.root'), 'tensorboard')

    restored = 'restore_snapshot' in trainer.runtime
    if osp.exists(tb_path) and not restored:
        # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
        logger.warning('Removing old tensorboard directory: {}.'.format(tb_path))
        shutil.rmtree(tb_path)
    io.mkdir(tb_path)
    trainer.runtime['tensorboard_summary_path'] = tb_path
    trainer._tensorboard_writer = tf.summary.FileWriter(tb_path, graph=trainer.env.graph)

    if enable_tensorboard_web:
        # randrange needs integer bounds: a float bound is deprecated since
        # Python 3.10 and raises TypeError from 3.12. 49152-65535 is the IANA
        # dynamic/private port range. The draw is always made (even when a port
        # is already recorded) to keep the global RNG stream unchanged.
        port = random.randrange(49152, 65536)
        # Reuse the port recorded in runtime, if any.
        port = trainer.runtime.get('tensorboard_web_port', port)
        trainer._tensorboard_webserver = threading.Thread(
            target=_tensorboard_webserver_thread,
            args=['tensorboard', '--logdir', tb_path, '--port', str(port)],
            daemon=True)
        trainer._tensorboard_webserver.start()
        trainer.runtime['tensorboard_web_port'] = port