def dump_weights_file(env, fpath, trainer=None):
    """Serialize every variable of ``env.network`` to a weights file.

    Args:
        env: object exposing ``network.fetch_all_variables_dict()``.
        fpath: destination path; normalized through ``io.assert_extension``
            against ``__weights_ext__`` before use.
        trainer: optional; when given, its ``'plugin:weights:dump'`` event is
            triggered with the weights dict before the dict is written.

    Returns:
        The (extension-checked) path that was written.
    """
    target = io.assert_extension(fpath, __weights_ext__)
    weights_dict = env.network.fetch_all_variables_dict()

    # Notify plugins before persisting; presumably they may inspect the dict.
    if trainer is not None:
        trainer.trigger_event('plugin:weights:dump', weights_dict)

    io.dump(target, weights_dict)
    return target
def dump_snapshot_on_epoch_after(trainer):
    """Dump a trainer snapshot every ``save_interval`` epochs and refresh aliases.

    The snapshot goes to ``<snapshot_dir>/epoch_<N><ext>``. It is then linked
    as ``last_epoch<ext>`` and, for each of ``loss`` / ``error`` present in
    ``trainer.runtime`` that beats the stored ``best_*`` value (default
    ``np.inf``), the stored best is updated and ``best_loss<ext>`` /
    ``best_error<ext>`` is added as an alias as well.

    ``save_interval`` is not defined here; it is assumed to come from an
    enclosing scope (e.g. the function that registers this hook).
    """
    if trainer.epoch % save_interval != 0:
        return

    snapshot_dir = get_snapshot_dir()
    snapshot = trainer.dump_snapshot()
    epoch_path = osp.join(snapshot_dir, 'epoch_{}'.format(trainer.epoch) + __snapshot_ext__)
    io.mkdir(osp.dirname(epoch_path))
    io.dump(epoch_path, snapshot)

    aliases = [osp.join(snapshot_dir, 'last_epoch' + __snapshot_ext__)]

    # (current-metric key, best-so-far key); loss is handled before error so the
    # alias order matches the log message.
    tracked = (('loss', 'best_loss'), ('error', 'best_error'))
    for metric_key, best_key in tracked:
        best_value = trainer.runtime.get(best_key, np.inf)
        current_value = trainer.runtime.get(metric_key, None)
        if current_value is None or not best_value > current_value:
            continue
        trainer.runtime[best_key] = current_value
        aliases.append(osp.join(snapshot_dir, best_key + __snapshot_ext__))

    io.link(epoch_path, *aliases)
    logger.info('Model at epoch {} dumped to {}.\n(Also alias: {}).'.format(
        trainer.epoch, epoch_path, ', '.join(aliases)))
def post_preference(self, uid, pref):
    """Record a preference for the stored pair ``uid`` and feed it to the predictor.

    Loads ``pair.pkl`` from the uid's directory, writes ``str(pref)`` next to
    it as ``pref.txt``, logs the event, and then forwards a ``TrainingData``
    built from both trajectories' states/actions plus ``pref`` to
    ``self._rpredictor.add_training_data``.
    """
    pair_dir = _compose_dir(uid)
    stored_pair = io.load(osp.join(pair_dir, 'pair.pkl'))
    io.dump(osp.join(pair_dir, 'pref.txt'), str(pref))
    logger.info('Post preference uid={}, pref={}.'.format(uid, pref))

    sample = TrainingData(
        stored_pair.t1_state,
        stored_pair.t1_action,
        stored_pair.t2_state,
        stored_pair.t2_action,
        pref,
    )
    # add_training_data acquires its lock internally, so no locking is done here.
    self._rpredictor.add_training_data(sample)
def _process(self, pair):
    """Persist a trajectory pair for display and return ``(uid, stripped_pair)``.

    A fresh hex uid is generated and its directory (``_compose_dir(uid)``) is
    created. The two trajectories' observations are rendered there as
    ``1.gif`` / ``2.gif``. A copy of the pair with the observation slots set to
    ``None`` is then pickled to ``pair.pkl`` and returned.

    Raises:
        ValueError: if ``self._tformat`` is not a supported format (only
            ``'GIF'``). The check runs before anything touches the filesystem,
            so an unsupported format no longer leaves an empty directory.
    """
    # Validate up front: previously the uid directory was created first and then
    # abandoned when the format turned out to be unsupported.
    if self._tformat != 'GIF':
        raise ValueError('Unknown trajectory format: {}'.format(self._tformat))

    uid = uuid.uuid4().hex
    dirname = _compose_dir(uid)
    io.mkdir(dirname)

    # Dump the files used for displaying the pair.
    _save_gif(pair.t1_observation, osp.join(dirname, '1.gif'))
    _save_gif(pair.t2_observation, osp.join(dirname, '2.gif'))

    # Strip observations before pickling. NOTE(review): the positional None
    # slots are assumed to be the observation fields of TrajectoryPair
    # (state, observation, action per trajectory) — confirm against its definition.
    pair = TrajectoryPair(pair.t1_state, None, pair.t1_action,
                          pair.t2_state, None, pair.t2_action)

    # Dump the raw (stripped) pair.
    io.dump(osp.join(dirname, 'pair.pkl'), pair)
    return uid, pair
def dump_weights_file(env, fpath, trainer=None):
    """Serialize every variable of ``env.network`` to a weights file.

    Args:
        env: object exposing ``network.fetch_all_variables_dict()``.
        fpath: destination path; normalized through ``io.assert_extension``
            against ``__weights_ext__`` before use.
        trainer: optional (new, defaults to ``None`` so existing two-argument
            callers are unaffected); when given, its ``'plugin:weights:dump'``
            event is triggered with the weights dict before it is written.

    Returns:
        The (extension-checked) path that was written.
    """
    # NOTE(review): another ``dump_weights_file`` with the same trainer-aware
    # signature appears nearby; if both live in one module this later definition
    # shadows it, so the signatures are kept in agreement.
    fpath = io.assert_extension(fpath, __weights_ext__)
    weights = env.network.fetch_all_variables_dict()
    if trainer is not None:
        trainer.trigger_event('plugin:weights:dump', weights)
    io.dump(fpath, weights)
    return fpath