Exemplo n.º 1
0
    def dump_snapshot_on_epoch_after(trainer):
        """Dump a snapshot of *trainer* every ``save_interval`` epochs and link aliases to it.

        The snapshot goes to ``epoch_<N><ext>`` inside ``get_snapshot_dir()``
        (``<ext>`` being ``__snapshot_ext__``). It is always aliased as
        ``last_epoch<ext>``. For each of ``loss`` and ``error``, if
        ``trainer.runtime`` holds a current value strictly below the recorded
        ``best_<metric>`` (initially ``np.inf``), the new best is stored back
        into ``trainer.runtime`` and ``best_<metric><ext>`` is aliased as well.
        """
        # Skip epochs that are not a multiple of the configured save interval.
        if trainer.epoch % save_interval != 0:
            return

        out_dir = get_snapshot_dir()
        state = trainer.dump_snapshot()
        snapshot_path = osp.join(out_dir, 'epoch_{}'.format(trainer.epoch) + __snapshot_ext__)
        io.mkdir(osp.dirname(snapshot_path))
        io.dump(snapshot_path, state)

        # The "last_epoch" alias always points at the most recent dump.
        aliases = [osp.join(out_dir, 'last_epoch' + __snapshot_ext__)]

        # An improved metric records its new best value and gets its own alias.
        for metric in ('loss', 'error'):
            best_key = 'best_' + metric
            best_value = trainer.runtime.get(best_key, np.inf)
            current_value = trainer.runtime.get(metric, None)
            if current_value is not None and best_value > current_value:
                trainer.runtime[best_key] = current_value
                aliases.append(osp.join(out_dir, best_key + __snapshot_ext__))

        io.link(snapshot_path, *aliases)

        logger.info('Model at epoch {} dumped to {}.\n(Also alias: {}).'.format(
            trainer.epoch, snapshot_path, ', '.join(aliases)))
Exemplo n.º 2
0
def dump_pack(cfg, pack):
    """Pickle *pack* into a timestamped replay file under ``replays/``.

    The file is named ``<cfg.name>-<YYYYmmdd-HHMMSS>.replay.pkl``, where the
    timestamp comes from ``time.strftime`` (local time). The ``replays``
    directory, relative to the current working directory, is passed to
    ``io.mkdir`` before writing. The payload is ``pack.make_pickleable()``,
    serialized with the highest available pickle protocol.
    """
    payload = pack.make_pickleable()

    stamp = time.strftime('%Y%m%d-%H%M%S')
    filename = cfg.name + '-' + stamp + '.replay.pkl'

    io.mkdir('replays')
    target = osp.join('replays', filename)
    with open(target, 'wb') as stream:
        pickle.dump(payload, stream, protocol=pickle.HIGHEST_PROTOCOL)

    logger.critical('replay written to replays/{}'.format(filename))
Exemplo n.º 3
0
    def __init__(self, name, dump_dir=None, force_dump=False, state_mode='DEFAULT'):
        """Create the wrapped gym environment.

        Args:
            name: environment id passed to ``gym.make``.
            dump_dir: if truthy, this directory is passed to ``io.mkdir`` and the
                environment is wrapped in ``gym.wrappers.Monitor`` writing there.
            force_dump: forwarded as ``force=`` to ``gym.wrappers.Monitor``.
            state_mode: one of ``'DEFAULT'``, ``'RENDER'`` or ``'BOTH'``.

        Raises:
            ValueError: if *state_mode* is not one of the accepted values.
        """
        # Validate first, so a bad argument neither builds the environment nor
        # touches dump_dir. Use an explicit raise so the check survives `python -O`.
        if state_mode not in ('DEFAULT', 'RENDER', 'BOTH'):
            raise ValueError(
                'Invalid state_mode: {!r}; expected one of DEFAULT, RENDER, BOTH.'.format(state_mode))

        super().__init__()

        # Environment construction is serialized through get_env_lock()
        # (presumably because gym.make is not safe to call concurrently -- confirm).
        with get_env_lock():
            self._gym = gym.make(name)

        if dump_dir:
            io.mkdir(dump_dir)
            self._gym = gym.wrappers.Monitor(self._gym, dump_dir, force=force_dump)

        self._state_mode = state_mode
Exemplo n.º 4
0
    def _process(self, pair):
        """Persist *pair* under a fresh UUID-named directory.

        The directory is ``_compose_dir(uid)`` for a new ``uuid4`` hex string.
        It receives ``1.gif`` / ``2.gif`` rendered from the two trajectories'
        observations, plus ``pair.pkl`` holding a copy of the pair with the
        observation slots set to ``None``.

        Returns:
            ``(uid, stripped_pair)``, where ``stripped_pair`` is the pair that
            was written to ``pair.pkl``.

        Raises:
            ValueError: if ``self._tformat`` is not ``'GIF'``. This is raised
                before any directory is created.
        """
        # Check the format up front so an unsupported format does not leave an
        # empty, orphaned directory behind.
        if self._tformat != 'GIF':
            raise ValueError('Unknown trajectory format: {}'.format(self._tformat))

        uid = uuid.uuid4().hex
        dirname = _compose_dir(uid)
        io.mkdir(dirname)

        # Dump the files used for displaying the pair.
        _save_gif(pair.t1_observation, osp.join(dirname, '1.gif'))
        _save_gif(pair.t2_observation, osp.join(dirname, '2.gif'))

        # Drop the observations (already rendered to the GIFs above) before
        # pickling. NOTE(review): assumes TrajectoryPair's positional order is
        # (t1_state, t1_observation, t1_action, t2_state, t2_observation, t2_action).
        stripped = TrajectoryPair(pair.t1_state, None, pair.t1_action, pair.t2_state, None, pair.t2_action)
        io.dump(osp.join(dirname, 'pair.pkl'), stripped)

        return uid, stripped
Exemplo n.º 5
0
 def tensorboard_summary_enable(trainer, tb_path=tensorboard_path):
     """Attach a TensorBoard ``FileWriter`` to *trainer*, optionally serving it over HTTP.

     Args:
         trainer: the trainer. ``trainer.runtime`` is read and updated
             (``tensorboard_summary_path``, and ``tensorboard_web_port`` when the
             web server is enabled). ``trainer._tensorboard_writer``, and
             ``trainer._tensorboard_webserver`` when enabled, are set on it.
         tb_path: summary directory. Defaults to ``tensorboard_path`` from the
             enclosing scope (bound at definition time). When ``None``,
             ``<get_env('dir.root')>/tensorboard`` is used.
     """
     if tb_path is None:
         tb_path = osp.join(get_env('dir.root'), 'tensorboard')

     # A fresh (non-restored) run wipes any existing summary directory. A run
     # restored from a snapshot keeps the existing directory contents.
     restored = 'restore_snapshot' in trainer.runtime
     if osp.exists(tb_path) and not restored:
         logger.warning('Removing old tensorboard directory: {}.'.format(tb_path))
         shutil.rmtree(tb_path)
     io.mkdir(tb_path)
     trainer.runtime['tensorboard_summary_path'] = tb_path
     trainer._tensorboard_writer = tf.summary.FileWriter(tb_path, graph=trainer.env.graph)

     if enable_tensorboard_web:
         # Pick a port in the dynamic/private range [49152, 65535]. The bounds
         # must be ints: random.randrange rejects float arguments with a
         # TypeError since Python 3.12.
         port = random.randrange(49152, 65536)
         # A port already recorded in trainer.runtime takes precedence
         # (presumably so a restored run reuses its previous port).
         port = trainer.runtime.get('tensorboard_web_port', port)
         # Run the server in a daemon thread so it does not block interpreter exit.
         trainer._tensorboard_webserver = threading.Thread(
             target=_tensorboard_webserver_thread,
             args=['tensorboard', '--logdir', tb_path, '--port', str(port)],
             daemon=True)
         trainer._tensorboard_webserver.start()
         trainer.runtime['tensorboard_web_port'] = port