예제 #1
0
def main_train(trainer):
    """Register the trainer plugins and the periodic play hook, then train.

    The summary plugin is enabled for the two asynchronous score channels
    (``async/train/score`` and ``async/inference/score``), followed by the
    epoch-progress and snapshot-saver plugins. An ``epoch:after`` callback
    calls ``main_inference_play_multithread`` on every even epoch after
    epoch 0. ``trainer.train()`` is invoked last.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.core import register_event
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary
    from common_hp_a3c import main_inference_play_multithread

    score_keys = ('async/train/score', 'async/inference/score')

    summary.enable_summary_history(
        trainer,
        extra_summary_types={key: 'async_scalar' for key in score_keys})
    summary.enable_echo_summary_scalar(
        trainer,
        summary_spec={key: ['avg', 'max'] for key in score_keys})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    def on_epoch_after(trainer):
        # Nothing to do on epoch 0 or on odd epochs.
        if trainer.epoch <= 0 or trainer.epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer, make_player=make_player)

    register_event(trainer, 'epoch:after', on_epoch_after, priority=5)

    trainer.train()
def main_train(trainer):
    """Set up the advantage computer, batch sampler and plugins, then train.

    The trainer receives a ``GAEComputer`` built from the ``ppo.gamma`` and
    ``ppo.gae.lambda`` environment values and a ``SimpleBatchSampler`` built
    from ``trainer.batch_size`` and ``trainer.data_repeat``. The summary,
    epoch-progress and snapshot-saver plugins are then enabled, and an
    ``epoch:after`` callback calls ``main_inference_play_multithread`` on
    every even epoch after epoch 0.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.app.rl.utils.adv import GAEComputer
    from tartist.random.sampler import SimpleBatchSampler
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    adv_computer = GAEComputer(
        get_env('ppo.gamma'), get_env('ppo.gae.lambda'))
    trainer.set_adv_computer(adv_computer)

    batch_sampler = SimpleBatchSampler(
        get_env('trainer.batch_size'), get_env('trainer.data_repeat'))
    trainer.set_batch_sampler(batch_sampler)

    # Register plugins.
    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def on_epoch_after(trainer):
        # Nothing to do on epoch 0 or on odd epochs.
        if trainer.epoch <= 0 or trainer.epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer)

    # This one should run before monitor.
    trainer.register_event('epoch:after', on_epoch_after, priority=5)

    trainer.train()
예제 #3
0
def main_train(trainer):
    """Enable the trainer plugins, add a learning-rate drop, then train.

    Enables the summary plugin (with ``'error'`` as the error summary key),
    the epoch-progress, snapshot-saver and inference-runner plugins. An
    ``epoch:after`` callback multiplies the optimizer's learning rate by 0.1
    when ``trainer.epoch`` equals 5.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.core import register_event
    from tartist.plugins.trainer_enhancer import (
        inference, progress, snapshot, summary)

    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    summary.set_error_summary_key(trainer, 'error')
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    def on_epoch_after(trainer):
        # The learning rate is only touched once, at epoch 5.
        if trainer.epoch != 5:
            return
        decayed_lr = trainer.optimizer.learning_rate * 0.1
        trainer.optimizer.set_learning_rate(decayed_lr)

    register_event(trainer, 'epoch:after', on_epoch_after)

    trainer.train()
예제 #4
0
def main_train(trainer):
    """Enable the summary, progress and snapshot plugins, then train.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Plugins are enabled in a fixed order: summary, progress, snapshot.
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    trainer.train()
예제 #5
0
def _default_main_train(trainer):
    """Train with summary/progress plugins and validation-based early stop.

    Two hooks are registered on the trainer:

    * ``optimization:before`` resets the validation loss history stored in
      ``trainer.runtime['validation_losses']`` and compiles an inference
      function on ``trainer.network.loss``.
    * ``epoch:after`` averages the loss over ``trainer.dataflow_validation``,
      logs and records it, and calls ``trainer.stop()`` when the loss
      increased in at least two of the (up to five) consecutive comparisons
      among the last six recorded validation losses.

    Args:
        trainer: the trainer object to configure and run.
    """
    def on_optimization_before(trainer):
        # clear the validation loss history
        trainer.runtime['validation_losses'] = []

        # compile the function for inference
        trainer.inference_func = trainer.env.make_func()
        trainer.inference_func.compile(trainer.network.loss)

    def on_epoch_after(trainer):
        # compute the validation loss
        sum_loss, nr_data = 0, 0
        for data in trainer.dataflow_validation:
            sum_loss += trainer.inference_func(**data)
            nr_data += 1

        if nr_data == 0:
            # An empty validation dataflow would otherwise raise
            # ZeroDivisionError below; there is nothing to record or compare.
            logger.warning(
                'Epoch: {}: validation dataflow is empty; '
                'skipping early stop check.'.format(trainer.epoch))
            return

        avg_loss = sum_loss / nr_data
        logger.info('Epoch: {}: average validation loss = {}.'.format(
            trainer.epoch, avg_loss))

        trainer.runtime['validation_losses'].append(avg_loss)

        # test whether early stop: six losses give five consecutive pairs
        losses = trainer.runtime['validation_losses'][-6:]
        if len(losses) <= 1:
            return

        # 2 out of 5: count the pairs where the later loss is larger
        nr_loss_increase = sum(
            1 for a, b in zip(losses[:-1], losses[1:]) if b > a)

        if nr_loss_increase >= 2:
            # acquire early stop
            logger.critical(
                'Validation loss is keeping increasing: acquire early stop.')
            trainer.stop()

    from tartist.plugins.trainer_enhancer import summary
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer,
                                       enable_json=False,
                                       enable_tensorboard=False)

    from tartist.plugins.trainer_enhancer import progress
    progress.enable_epoch_progress(trainer)

    # early stop related hooks
    trainer.register_event('optimization:before', on_optimization_before)
    trainer.register_event('epoch:after', on_epoch_after)

    trainer.train()
def main_train(trainer):
    """Enable plugins, schedule the exploration epsilon, then train.

    The summary plugin tracks ``inference/score`` and ``train/exp_epsilon``
    as asynchronous scalars. Before each epoch the exploration epsilon is
    looked up in a piecewise-constant schedule and written to
    ``trainer.runtime['exp_epsilon']``. After each epoch,
    ``main_inference_play_multithread`` is called on even epochs after
    epoch 0, and the current epsilon is pushed to the summary histories when
    they are present in ``trainer.runtime``.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Register plugins.
    summary_types = {
        'inference/score': 'async_scalar',
        'train/exp_epsilon': 'async_scalar',
    }
    summary.enable_summary_history(trainer, extra_summary_types=summary_types)
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={'inference/score': ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=2)

    def set_exp_epsilon(trainer_or_env, value):
        # Only log and write when the value actually changes.
        runtime = trainer_or_env.runtime
        if runtime.get('exp_epsilon', None) == value:
            return
        logger.critical('Setting exploration epsilon to {}'.format(value))
        runtime['exp_epsilon'] = value

    # (epoch boundary, epsilon) pairs; an entry's epsilon applies from its
    # boundary up to the next entry's boundary.
    schedule = [(0, 0.1), (10, 0.1), (250, 0.01), (1e9, 0.01)]

    def schedule_exp_epsilon(trainer):
        # `trainer.runtime` is synchronous with `trainer.env.runtime`.
        previous_epsilon = None
        for boundary, epsilon in schedule:
            if trainer.epoch < boundary:
                set_exp_epsilon(trainer, previous_epsilon)
                return
            previous_epsilon = epsilon

    def on_epoch_after(trainer):
        is_play_epoch = trainer.epoch > 0 and trainer.epoch % 2 == 0
        if is_play_epoch:
            main_inference_play_multithread(trainer)

        # Summarize the exp epsilon.
        histories = trainer.runtime.get('summary_histories', None)
        if histories is None:
            return
        histories.put_async_scalar('train/exp_epsilon',
                                   trainer.runtime['exp_epsilon'])

    # This one should run before monitor.
    trainer.register_event('epoch:before', schedule_exp_epsilon, priority=5)
    trainer.register_event('epoch:after', on_epoch_after, priority=5)

    trainer.train()
예제 #7
0
def main_train(trainer):
    """Enable the summary, progress, snapshot and inference plugins; train.

    The inference runner is enabled with ``make_dataflow_inference``.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import (
        inference, progress, snapshot, summary)

    # Plugins are enabled in a fixed order before training starts.
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    trainer.train()
def main_train(trainer):
    """Install the evaluator, plugins and epoch hooks, then train.

    A player is created once with ``make_player()`` and bound into the
    evaluator passed to ``trainer.set_evaluator``. The summary plugin tracks
    ``inference/score`` as an asynchronous scalar. Before each epoch,
    ``max(5 - epoch / 10, 0)`` is added to ``trainer.optimizer.param_std``.
    After each epoch, ``main_inference_play_multithread`` is called on even
    epochs after epoch 0.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Compose the evaluator
    player = make_player()

    def evaluate_train(trainer, p=player):
        # The shared player is bound as a default argument.
        return _evaluate(player=p, func=trainer.pred_func)

    trainer.set_evaluator(evaluate_train)

    # Register plugins
    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def on_epoch_before(trainer):
        # The added amount shrinks with the epoch and is clamped at zero.
        extra_std = max(5 - trainer.epoch / 10, 0)
        trainer.optimizer.param_std += extra_std

    def on_epoch_after(trainer):
        # Nothing to do on epoch 0 or on odd epochs.
        if trainer.epoch <= 0 or trainer.epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer)

    trainer.register_event('epoch:before', on_epoch_before, priority=5)
    # This one should run before monitor.
    trainer.register_event('epoch:after', on_epoch_after, priority=5)

    trainer.train()