示例#1
0
def dump_weights_file(env, fpath, trainer=None):
    """Write every variable of ``env.network`` to a weights file.

    Args:
        env: environment whose ``network.fetch_all_variables_dict()`` supplies
            the variables to save.
        fpath: destination path. It is normalised through
            ``io.assert_extension(fpath, __weights_ext__)`` before use
            (presumably adding/checking the weights extension; confirm in io).
        trainer: optional trainer. When given, its ``'plugin:weights:dump'``
            event is triggered with the variables dict before writing.

    Returns:
        The path returned by ``io.assert_extension``, i.e. the file written.
    """
    out_path = io.assert_extension(fpath, __weights_ext__)
    variables = env.network.fetch_all_variables_dict()

    # The very same dict object is dumped below, so any in-place change a
    # plugin handler makes to it ends up in the file.
    has_trainer = trainer is not None
    if has_trainer:
        trainer.trigger_event('plugin:weights:dump', variables)

    io.dump(out_path, variables)
    return out_path
示例#2
0
    def dump_snapshot_on_epoch_after(trainer):
        """Epoch-end hook: save a trainer snapshot every ``save_interval`` epochs.

        On a saving epoch the snapshot is written to
        ``<snapshot_dir>/epoch_<N><__snapshot_ext__>`` and linked under the
        alias ``last_epoch``. It is also linked under ``best_loss`` /
        ``best_error`` when ``trainer.runtime['loss']`` / ``['error']``
        improves on the stored best. The stored bests live in
        ``trainer.runtime`` and are only updated on saving epochs, so every
        alias points to a file that exists.

        Args:
            trainer: trainer exposing ``epoch``, ``runtime`` (dict-like) and
                ``dump_snapshot()``.
        """
        if trainer.epoch % save_interval != 0:
            return

        out_dir = get_snapshot_dir()
        snapshot = trainer.dump_snapshot()
        snapshot_path = osp.join(out_dir, 'epoch_{}'.format(trainer.epoch) + __snapshot_ext__)
        io.mkdir(osp.dirname(snapshot_path))
        io.dump(snapshot_path, snapshot)

        aliases = [osp.join(out_dir, 'last_epoch' + __snapshot_ext__)]

        # Pairs of (runtime key of the best value so far, runtime key of the
        # current value). The best key doubles as the alias file name.
        tracked = (('best_loss', 'loss'), ('best_error', 'error'))
        for best_key, current_key in tracked:
            best_value = trainer.runtime.get(best_key, np.inf)
            current_value = trainer.runtime.get(current_key, None)
            if current_value is None or not best_value > current_value:
                continue
            trainer.runtime[best_key] = current_value
            aliases.append(osp.join(out_dir, best_key + __snapshot_ext__))

        io.link(snapshot_path, *aliases)

        logger.info('Model at epoch {} dumped to {}.\n(Also alias: {}).'.format(
            trainer.epoch, snapshot_path, ', '.join(aliases)))
示例#3
0
    def post_preference(self, uid, pref):
        """Record a preference label for the stored trajectory pair ``uid``.

        The method does the following:

        - Loads ``pair.pkl`` from the directory given by ``_compose_dir(uid)``.
        - Writes ``str(pref)`` to ``pref.txt`` in the same directory.
        - Passes a ``TrainingData`` built from the pair's states/actions and
          ``pref`` to ``self._rpredictor.add_training_data``.

        Args:
            uid: identifier of a pair previously stored on disk.
            pref: preference value, stored as ``str(pref)``.
        """
        # NOTE(review): pair.pkl is presumably unpickled from a path derived
        # from ``uid``. If ``uid`` can come from an untrusted client, confirm
        # that ``_compose_dir`` rejects path traversal.
        pair_dir = _compose_dir(uid)
        stored = io.load(osp.join(pair_dir, 'pair.pkl'))
        io.dump(osp.join(pair_dir, 'pref.txt'), str(pref))

        logger.info('Post preference uid={}, pref={}.'.format(uid, pref))

        sample = TrainingData(
            stored.t1_state, stored.t1_action,
            stored.t2_state, stored.t2_action,
            pref,
        )
        # No lock is taken here: add_training_data acquires it internally.
        self._rpredictor.add_training_data(sample)
示例#4
0
    def _process(self, pair):
        uid = uuid.uuid4().hex
        dirname = _compose_dir(uid)
        io.mkdir(dirname)
       
        # dump the file for displaying
        if self._tformat == 'GIF':
            _save_gif(pair.t1_observation, osp.join(dirname, '1.gif'))
            _save_gif(pair.t2_observation, osp.join(dirname, '2.gif'))
        else:
            raise ValueError('Unknown trajectory format: {}'.format(self._tformat))

        # cleanup
        pair = TrajectoryPair(pair.t1_state, None, pair.t1_action, pair.t2_state, None, pair.t2_action)
        # dump the raw pair
        io.dump(osp.join(dirname, 'pair.pkl'), pair)

        return uid, pair
示例#5
0
def dump_weights_file(env, fpath, trainer=None):
    """Write every variable of ``env.network`` to a weights file.

    The optional ``trainer`` parameter matches the other
    ``dump_weights_file`` variant. Callers passing ``trainer=`` therefore keep
    working even when this definition is the one in effect, and existing
    two-argument callers are unaffected.

    Args:
        env: environment whose ``network.fetch_all_variables_dict()`` supplies
            the variables to save.
        fpath: destination path. It is normalised through
            ``io.assert_extension(fpath, __weights_ext__)`` before use
            (presumably adding/checking the weights extension; confirm in io).
        trainer: optional trainer. When given, its ``'plugin:weights:dump'``
            event is triggered with the variables dict before writing.

    Returns:
        The path returned by ``io.assert_extension``, i.e. the file written.
    """
    fpath = io.assert_extension(fpath, __weights_ext__)
    weights = env.network.fetch_all_variables_dict()
    # The same dict object is dumped afterwards, so any in-place change a
    # plugin handler makes to it is persisted.
    if trainer is not None:
        trainer.trigger_event('plugin:weights:dump', weights)
    io.dump(fpath, weights)
    return fpath