コード例 #1
0
def main_train(trainer):
    """Configure a PPO trainer with its helpers and plugins, then train.

    The advantage computer is a ``GAEComputer`` built from the ``ppo.gamma``
    and ``ppo.gae.lambda`` environment values. The batch sampler is a
    ``SimpleBatchSampler`` built from ``trainer.batch_size`` and
    ``trainer.data_repeat``.

    The following plugins are enabled:

    * summary history, with ``inference/score`` tracked as an async scalar;
    * echo of that score's avg/max;
    * epoch progress;
    * a snapshot saver (``save_interval=1``).

    After every even, positive epoch an inference play-through is run.
    """
    from tartist.app.rl.utils.adv import GAEComputer
    from tartist.random.sampler import SimpleBatchSampler

    gamma = get_env('ppo.gamma')
    gae_lambda = get_env('ppo.gae.lambda')
    trainer.set_adv_computer(GAEComputer(gamma, gae_lambda))

    batch_size = get_env('trainer.batch_size')
    data_repeat = get_env('trainer.data_repeat')
    trainer.set_batch_sampler(SimpleBatchSampler(batch_size, data_repeat))

    # Register plugins (same order as before: summary, progress, snapshot).
    from tartist.plugins.trainer_enhancer import summary, progress, snapshot

    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def play_inference_every_other_epoch(trainer):
        epoch = trainer.epoch
        if epoch > 0 and not epoch % 2:
            main_inference_play_multithread(trainer)

    # priority=5 is meant to make this hook run before the monitor's handler.
    trainer.register_event(
        'epoch:after', play_inference_every_other_epoch, priority=5)

    trainer.train()
コード例 #2
0
def main_train(trainer):
    """Enable A3C training plugins, hook periodic inference, and train.

    ``async/train/score`` and ``async/inference/score`` are both recorded as
    async scalars and echoed with their average and maximum. Epoch progress
    and the snapshot saver are also enabled. After every even, positive epoch
    ``main_inference_play_multithread`` runs with the ``make_player`` factory
    found in the enclosing module scope.
    """
    from tartist.plugins.trainer_enhancer import summary, progress, snapshot

    score_keys = ('async/train/score', 'async/inference/score')
    summary.enable_summary_history(
        trainer,
        extra_summary_types={key: 'async_scalar' for key in score_keys})
    summary.enable_echo_summary_scalar(
        trainer,
        summary_spec={key: ['avg', 'max'] for key in score_keys})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    from tartist.core import register_event
    from common_hp_a3c import main_inference_play_multithread

    def play_inference_every_other_epoch(trainer):
        epoch = trainer.epoch
        if epoch > 0 and not epoch % 2:
            main_inference_play_multithread(trainer, make_player=make_player)

    register_event(trainer, 'epoch:after', play_inference_every_other_epoch,
                   priority=5)

    trainer.train()
コード例 #3
0
ファイル: desc_mnist.py プロジェクト: vacancy/TensorArtist
def main_train(trainer):
    """Enable the standard plugins plus an inference runner, then train.

    The summary plugin uses ``'error'`` as its error summary key. Inference
    batches come from ``make_dataflow_inference``. After epoch 5 the
    optimizer's learning rate is multiplied by 0.1, once.
    """
    from tartist.plugins.trainer_enhancer import (
        summary, progress, snapshot, inference)

    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    summary.set_error_summary_key(trainer, 'error')
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    from tartist.core import register_event

    def decay_learning_rate_after_epoch_5(trainer):
        if trainer.epoch != 5:
            return
        optimizer = trainer.optimizer
        optimizer.set_learning_rate(optimizer.learning_rate * 0.1)

    register_event(trainer, 'epoch:after', decay_learning_rate_after_epoch_5)

    trainer.train()
コード例 #4
0
def main_train(trainer):
    """Enable summary, progress and snapshot plugins, then start training."""
    from tartist.plugins.trainer_enhancer import summary, progress, snapshot

    # Summary plugins: keep history and echo scalar summaries.
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    trainer.train()
コード例 #5
0
def main_train(trainer):
    """Train with a step-wise exploration-epsilon schedule and periodic inference.

    ``inference/score`` and ``train/exp_epsilon`` are recorded as async
    scalars. Only the score is echoed, as avg/max.

    Before each epoch, ``trainer.runtime['exp_epsilon']`` is set from the
    step schedule below. After each epoch:

    * an inference play-through runs on even, positive epochs;
    * the current epsilon is pushed to the summary histories, when that
      plugin is present.
    """
    # Register plugins.
    from tartist.plugins.trainer_enhancer import summary, progress, snapshot

    summary.enable_summary_history(
        trainer,
        extra_summary_types={
            'inference/score': 'async_scalar',
            'train/exp_epsilon': 'async_scalar',
        })
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={'inference/score': ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=2)

    def set_exp_epsilon(trainer_or_env, value):
        # Only write (and log) when the value actually changes.
        runtime = trainer_or_env.runtime
        if runtime.get('exp_epsilon', None) == value:
            return
        logger.critical('Setting exploration epsilon to {}'.format(value))
        runtime['exp_epsilon'] = value

    # (epoch boundary, epsilon) pairs. An epoch uses the epsilon of the last
    # entry whose boundary it has reached. No update happens for epochs
    # beyond the final boundary.
    epsilon_schedule = [(0, 0.1), (10, 0.1), (250, 0.01), (1e9, 0.01)]

    def schedule_exp_epsilon(trainer):
        # `trainer.runtime` is synchronous with `trainer.env.runtime`.
        previous = None
        for boundary, epsilon in epsilon_schedule:
            if trainer.epoch < boundary:
                set_exp_epsilon(trainer, previous)
                return
            previous = epsilon

    def on_epoch_after(trainer):
        epoch = trainer.epoch
        if epoch > 0 and not epoch % 2:
            main_inference_play_multithread(trainer)

        # Summarize the exp epsilon (skipped when no summary history exists).
        histories = trainer.runtime.get('summary_histories', None)
        if histories is None:
            return
        histories.put_async_scalar('train/exp_epsilon',
                                   trainer.runtime['exp_epsilon'])

    # priority=5 is meant to make these hooks run before the monitor's.
    trainer.register_event('epoch:before', schedule_exp_epsilon, priority=5)
    trainer.register_event('epoch:after', on_epoch_after, priority=5)

    trainer.train()
コード例 #6
0
def main_train(trainer):
    """Enable the standard plugins and an inference runner, then train.

    Inference batches come from ``make_dataflow_inference``.
    """
    from tartist.plugins.trainer_enhancer import (
        summary, progress, snapshot, inference)

    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    trainer.train()
コード例 #7
0
def main_train(trainer):
    """Train with a player-based evaluator, decaying extra noise, and inference.

    A single player from ``make_player()`` is bound into the evaluator, which
    scores ``trainer.pred_func`` with it. Before each epoch,
    ``max(5 - epoch / 10, 0)`` is added to ``trainer.optimizer.param_std``.
    After every even, positive epoch an inference play-through runs.
    """
    # Compose the evaluator around one shared player instance.
    shared_player = make_player()

    def evaluate_with_player(trainer, p=shared_player):
        return _evaluate(player=p, func=trainer.pred_func)

    trainer.set_evaluator(evaluate_with_player)

    # Register plugins.
    from tartist.plugins.trainer_enhancer import summary, progress, snapshot

    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def add_decaying_param_noise(trainer):
        # The extra std shrinks linearly with the epoch and is floored at 0.
        extra_std = max(5 - trainer.epoch / 10, 0)
        trainer.optimizer.param_std += extra_std

    def play_inference_every_other_epoch(trainer):
        epoch = trainer.epoch
        if epoch > 0 and not epoch % 2:
            main_inference_play_multithread(trainer)

    trainer.register_event('epoch:before', add_decaying_param_noise,
                           priority=5)
    # priority=5 is meant to make this hook run before the monitor's handler.
    trainer.register_event('epoch:after', play_inference_every_other_epoch,
                           priority=5)

    trainer.train()