def load_snapshot_file(trainer, fpath):
    """Load a full training snapshot from fpath and restore the trainer from it."""
    # Ensure the path carries the snapshot extension before loading.
    fpath = io.assert_extension(fpath, __snapshot_ext__)
    snapshot = io.load(fpath)
    if snapshot is None:
        return False
    trainer.load_snapshot(snapshot)
    return True

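# Hedged usage sketch (not part of the original listing): resuming a trainer
# from a snapshot path. The path 'dumps/last_snapshot' and the helper name
# _example_resume_training are illustrative assumptions; io.assert_extension
# inside load_snapshot_file takes care of appending __snapshot_ext__.
def _example_resume_training(trainer):
    if not load_snapshot_file(trainer, 'dumps/last_snapshot'):
        logger.info('No snapshot found; starting training from scratch.')
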
def load_weights_file(env, fpath):
    """Load network weights from fpath and assign them to env's network."""
    weights = io.load(fpath)
    if weights is None:
        return False
    # A full snapshot file stores the weights under the 'variables' key.
    if fpath.endswith(__snapshot_ext__):
        weights = weights['variables']
    env.network.assign_all_variables_dict(weights)
    return True

def post_preference(self, uid, pref):
    """Record a human preference for the trajectory pair identified by uid."""
    dirname = _compose_dir(uid)
    pair = io.load(osp.join(dirname, 'pair.pkl'))
    io.dump(osp.join(dirname, 'pref.txt'), str(pref))
    logger.info('Post preference uid={}, pref={}.'.format(uid, pref))
    data = TrainingData(pair.t1_state, pair.t1_action, pair.t2_state, pair.t2_action, pref)
    # Lock acquired inside this function call.
    self._rpredictor.add_training_data(data)

def load_weights_file(env, fpath, trainer=None):
    """Load network weights from fpath; optionally notify the trainer's plugins.

    Extends the variant above with an optional trainer hook.
    """
    weights = io.load(fpath)
    if weights is None:
        return False
    # A full snapshot file stores the weights under the 'variables' key.
    if fpath.endswith(__snapshot_ext__):
        weights = weights['variables']
    if trainer is not None:
        trainer.trigger_event('plugin:weights:load', weights)
    env.network.assign_all_variables_dict(weights)
    return True

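# Hedged usage sketch (assumed call site, not from the original listing):
# load_weights_file accepts either a plain weights dump or a full snapshot
# file (only its 'variables' dict is used), and passing a trainer also fires
# the 'plugin:weights:load' event so plugins can react to the loaded weights.
# The helper name _example_load_pretrained is an illustrative assumption.
def _example_load_pretrained(env, trainer, fpath):
    if not load_weights_file(env, fpath, trainer=trainer):
        logger.info('Failed to load weights from {}.'.format(fpath))
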
def __restore_preferences(self):
    """Restore the preferences we already have."""
    dirname = osp.join(get_env('dir.root'), 'trajectories')
    if not osp.isdir(dirname):
        return
    all_data = []
    logger.critical('Restoring preference')
    for uid in os.listdir(dirname):
        item = osp.join(dirname, uid)
        pref_filename = osp.join(item, 'pref.txt')
        pair_filename = osp.join(item, 'pair.pkl')
        # Only pairs that have already received a preference are restored.
        if osp.exists(pref_filename) and osp.exists(pair_filename):
            pref = float(io.load(pref_filename)[0])
            pair = io.load(pair_filename)
            data = TrainingData(pair.t1_state, pair.t1_action, pair.t2_state, pair.t2_action, pref)
            all_data.append(data)
    if len(all_data) > 0:
        self._rpredictor.extend_training_data(all_data)
    logger.critical('Preference restore finished: success={}'.format(len(all_data)))
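
# Hedged sketch of the on-disk layout assumed by post_preference and
# __restore_preferences: each trajectory pair lives under
# <dir.root>/trajectories/<uid>/ as 'pair.pkl' (the two trajectories) plus
# 'pref.txt' (the human preference). _compose_dir itself is not shown above;
# a consistent implementation could look like the following, where creating
# the directory on demand is an assumption.
def _compose_dir(uid):
    dirname = osp.join(get_env('dir.root'), 'trajectories', str(uid))
    os.makedirs(dirname, exist_ok=True)
    return dirname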