def main_train(trainer):
    """Attach the trainer plugins and a periodic inference hook, then train.

    Enables summary history and scalar echo for the asynchronous train and
    inference scores, the epoch progress plugin and the snapshot saver.
    An ``epoch:after`` hook registered with priority 5 runs
    ``main_inference_play_multithread`` after every even epoch greater
    than zero.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.core import register_event
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary
    from common_hp_a3c import main_inference_play_multithread

    score_keys = ('async/train/score', 'async/inference/score')

    summary.enable_summary_history(
        trainer,
        extra_summary_types={key: 'async_scalar' for key in score_keys},
    )
    summary.enable_echo_summary_scalar(
        trainer,
        summary_spec={key: ['avg', 'max'] for key in score_keys},
    )
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    def run_periodic_inference(trainer):
        # Skip epoch 0 and every odd epoch.
        epoch = trainer.epoch
        if epoch <= 0 or epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer, make_player=make_player)

    register_event(trainer, 'epoch:after', run_periodic_inference, priority=5)
    trainer.train()
def main_train(trainer):
    """Configure advantage computation, batching and plugins, then train.

    Installs a ``GAEComputer`` built from the ``ppo.gamma`` and
    ``ppo.gae.lambda`` settings and a ``SimpleBatchSampler`` built from
    ``trainer.batch_size`` and ``trainer.data_repeat``. Then enables the
    summary, progress and snapshot plugins (snapshot ``save_interval=1``)
    and registers a hook that runs multi-threaded inference play after
    every even epoch greater than zero.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.app.rl.utils.adv import GAEComputer
    from tartist.random.sampler import SimpleBatchSampler
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    adv_computer = GAEComputer(
        get_env('ppo.gamma'), get_env('ppo.gae.lambda'))
    trainer.set_adv_computer(adv_computer)

    batch_sampler = SimpleBatchSampler(
        get_env('trainer.batch_size'), get_env('trainer.data_repeat'))
    trainer.set_batch_sampler(batch_sampler)

    # Register plugins.
    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def run_periodic_inference(trainer):
        # Skip epoch 0 and every odd epoch.
        epoch = trainer.epoch
        if epoch <= 0 or epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer)

    # Original author's note: this hook should run before the monitor,
    # hence the explicit priority.
    trainer.register_event('epoch:after', run_periodic_inference, priority=5)
    trainer.train()
def main_train(trainer):
    """Enable the plugins, schedule one learning-rate drop, then train.

    Turns on summary history and scalar echo (with ``'error'`` as the
    error summary key), epoch progress, the snapshot saver and the
    inference runner fed by ``make_dataflow_inference``. An
    ``epoch:after`` hook multiplies the optimizer's learning rate by 0.1
    when the epoch counter equals 5.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.core import register_event
    from tartist.plugins.trainer_enhancer import (
        inference, progress, snapshot, summary)

    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    summary.set_error_summary_key(trainer, 'error')
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    def decay_learning_rate(trainer):
        # One-off decay: only fires for epoch 5.
        if trainer.epoch != 5:
            return
        new_rate = trainer.optimizer.learning_rate * 0.1
        trainer.optimizer.set_learning_rate(new_rate)

    register_event(trainer, 'epoch:after', decay_learning_rate)
    trainer.train()
def main_train(trainer):
    """Enable the summary, progress and snapshot plugins, then train.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Plugins are enabled in the same order as before: summary history,
    # scalar echo, epoch progress, snapshot saver.
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)

    trainer.train()
def _default_main_train(trainer):
    """Train with per-epoch validation and a simple early-stopping rule.

    Before optimization starts, the validation-loss history in
    ``trainer.runtime['validation_losses']`` is reset and an inference
    function for ``trainer.network.loss`` is compiled. After every epoch
    the average loss over ``trainer.dataflow_validation`` is logged and
    appended to the history. Training is stopped via ``trainer.stop()``
    when at least two of the consecutive changes among the last six
    recorded losses (at most five changes) are increases.

    If the validation dataflow yields no data in an epoch, a warning is
    logged and that epoch is skipped instead of dividing by zero.

    Args:
        trainer: the trainer object to configure and run.
    """

    def on_optimization_before(trainer):
        # Clear the validation loss history.
        trainer.runtime['validation_losses'] = []

        # Compile the function used for validation inference.
        trainer.inference_func = trainer.env.make_func()
        trainer.inference_func.compile(trainer.network.loss)

    def on_epoch_after(trainer):
        # Compute the average validation loss.
        sum_loss, nr_data = 0, 0
        for data in trainer.dataflow_validation:
            sum_loss += trainer.inference_func(**data)
            nr_data += 1

        if nr_data == 0:
            # No validation data: an average is undefined, so record nothing
            # and leave the early-stop state untouched.
            logger.warning(
                'Epoch: {}: validation dataflow is empty; '
                'skip validation.'.format(trainer.epoch))
            return

        avg_loss = sum_loss / nr_data
        logger.info('Epoch: {}: average validation loss = {}.'.format(
            trainer.epoch, avg_loss))
        trainer.runtime['validation_losses'].append(avg_loss)

        # Early-stop test on the six most recent losses (five transitions).
        losses = trainer.runtime['validation_losses'][-6:]
        if len(losses) <= 1:
            return

        # Stop when 2 out of (up to) 5 transitions are increases.
        nr_loss_increase = sum(
            1 for a, b in zip(losses[:-1], losses[1:]) if b > a)
        if nr_loss_increase >= 2:
            logger.critical(
                'Validation loss is keeping increasing: acquire early stop.')
            trainer.stop()

    from tartist.plugins.trainer_enhancer import summary
    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(
        trainer, enable_json=False, enable_tensorboard=False)

    from tartist.plugins.trainer_enhancer import progress
    progress.enable_epoch_progress(trainer)

    # Early-stop related hooks.
    trainer.register_event('optimization:before', on_optimization_before)
    trainer.register_event('epoch:after', on_epoch_after)

    trainer.train()
def main_train(trainer):
    """Enable plugins, drive the exploration-epsilon schedule, then train.

    Enables summary history (with ``inference/score`` and
    ``train/exp_epsilon`` as async scalars), scalar echo for the inference
    score, epoch progress and the snapshot saver (``save_interval=2``).
    Before each epoch the exploration epsilon stored in
    ``trainer.runtime['exp_epsilon']`` is updated from a piecewise-constant
    schedule. After each epoch the current epsilon is pushed to the summary
    history if one exists, and multi-threaded inference play runs on every
    even epoch greater than zero.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Register plugins.
    summary.enable_summary_history(trainer, extra_summary_types={
        'inference/score': 'async_scalar',
        'train/exp_epsilon': 'async_scalar'
    })
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={'inference/score': ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=2)

    def set_exp_epsilon(trainer_or_env, value):
        # Only log and write when the value actually changes.
        runtime = trainer_or_env.runtime
        if runtime.get('exp_epsilon', None) != value:
            logger.critical('Setting exploration epsilon to {}'.format(value))
            runtime['exp_epsilon'] = value

    # (epoch boundary, epsilon) pairs; the epsilon in effect is the one
    # attached to the last boundary that is <= the current epoch.
    epsilon_schedule = [(0, 0.1), (10, 0.1), (250, 0.01), (1e9, 0.01)]

    def schedule_exp_epsilon(trainer):
        # Original author's note: `trainer.runtime` is synchronous with
        # `trainer.env.runtime`.
        previous = None
        for boundary, epsilon in epsilon_schedule:
            if trainer.epoch < boundary:
                set_exp_epsilon(trainer, previous)
                return
            previous = epsilon

    def on_epoch_after(trainer):
        epoch = trainer.epoch
        if epoch > 0 and epoch % 2 == 0:
            main_inference_play_multithread(trainer)

        # Summarize the exp epsilon.
        histories = trainer.runtime.get('summary_histories', None)
        if histories is None:
            return
        histories.put_async_scalar(
            'train/exp_epsilon', trainer.runtime['exp_epsilon'])

    # Original author's note: this one should run before the monitor,
    # hence the explicit priority.
    trainer.register_event('epoch:before', schedule_exp_epsilon, priority=5)
    trainer.register_event('epoch:after', on_epoch_after, priority=5)
    trainer.train()
def main_train(trainer):
    """Enable the plugins and the inference runner, then train.

    Turns on summary history, scalar echo, epoch progress, the snapshot
    saver and an inference runner fed by ``make_dataflow_inference``.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import (
        inference, progress, snapshot, summary)

    summary.enable_summary_history(trainer)
    summary.enable_echo_summary_scalar(trainer)
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer)
    inference.enable_inference_runner(trainer, make_dataflow_inference)

    trainer.train()
def main_train(trainer):
    """Install an evaluator, plugins and per-epoch hooks, then train.

    A player is created once with ``make_player()`` and bound into the
    evaluator, which scores ``trainer.pred_func`` through ``_evaluate``.
    The summary, progress and snapshot plugins are enabled (snapshot
    ``save_interval=1``). Before each epoch ``max(5 - epoch / 10, 0)`` is
    added to ``trainer.optimizer.param_std``. After each even epoch greater
    than zero, multi-threaded inference play is run.

    Args:
        trainer: the trainer object to configure and run.
    """
    from tartist.plugins.trainer_enhancer import progress, snapshot, summary

    # Compose the evaluator.
    player = make_player()

    def evaluate_train(trainer, p=player):
        return _evaluate(player=p, func=trainer.pred_func)

    trainer.set_evaluator(evaluate_train)

    # Register plugins.
    score_key = 'inference/score'
    summary.enable_summary_history(
        trainer, extra_summary_types={score_key: 'async_scalar'})
    summary.enable_echo_summary_scalar(
        trainer, summary_spec={score_key: ['avg', 'max']})
    progress.enable_epoch_progress(trainer)
    snapshot.enable_snapshot_saver(trainer, save_interval=1)

    def bump_param_std(trainer):
        # NOTE(review): this accumulates into param_std every epoch rather
        # than assigning it -- confirm that accumulation is intended.
        increment = max(5 - trainer.epoch / 10, 0)
        trainer.optimizer.param_std += increment

    def run_periodic_inference(trainer):
        # Skip epoch 0 and every odd epoch.
        epoch = trainer.epoch
        if epoch <= 0 or epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer)

    trainer.register_event('epoch:before', bump_param_std, priority=5)
    # Original author's note: this one should run before the monitor,
    # hence the explicit priority.
    trainer.register_event('epoch:after', run_periodic_inference, priority=5)
    trainer.train()