コード例 #1
0
def main_train(trainer):
    """Install the training plugins on ``trainer`` and start training.

    Steps, in order:

    - Enable summary history, registering the two async score keys as
      ``'async_scalar'`` summaries.
    - Echo those two keys with their ``avg`` and ``max`` statistics.
    - Enable the per-epoch progress plugin and the snapshot saver.
    - Register an ``'epoch:after'`` handler (priority 5). It calls
      ``common_hp_a3c.main_inference_play_multithread`` with the module-level
      ``make_player`` on every positive even epoch.
    - Finally, call ``trainer.train()``.
    """
    score_keys = ('async/train/score', 'async/inference/score')

    from tartist.plugins.trainer_enhancer import summary as summary_plugin
    summary_plugin.enable_summary_history(
        trainer,
        extra_summary_types={key: 'async_scalar' for key in score_keys},
    )
    # Each key gets its own fresh ['avg', 'max'] list, as with two literals.
    summary_plugin.enable_echo_summary_scalar(
        trainer,
        summary_spec={key: ['avg', 'max'] for key in score_keys},
    )

    from tartist.plugins.trainer_enhancer import progress as progress_plugin
    progress_plugin.enable_epoch_progress(trainer)

    from tartist.plugins.trainer_enhancer import snapshot as snapshot_plugin
    snapshot_plugin.enable_snapshot_saver(trainer)

    from tartist.core import register_event
    from common_hp_a3c import main_inference_play_multithread

    def _play_on_even_epochs(trainer):
        # Never on epoch 0; afterwards only on every second epoch.
        epoch = trainer.epoch
        if epoch <= 0 or epoch % 2 != 0:
            return
        main_inference_play_multithread(trainer, make_player=make_player)

    register_event(trainer, 'epoch:after', _play_on_even_epochs, priority=5)

    trainer.train()
コード例 #2
0
ファイル: desc_mnist.py プロジェクト: vacancy/TensorArtist
def main_train(trainer):
    """Install the MNIST training plugins on ``trainer`` and start training.

    Steps, in order:

    - Enable summary history and scalar echoing, with ``'error'`` as the
      error summary key.
    - Enable the per-epoch progress plugin and the snapshot saver.
    - Enable an inference runner built from the module-level
      ``make_dataflow_inference``.
    - Register an ``'epoch:after'`` handler that multiplies the optimizer's
      learning rate by 0.1 when ``trainer.epoch == 5``.
    - Finally, call ``trainer.train()``.
    """
    from tartist.plugins.trainer_enhancer import summary as summary_plugin
    summary_plugin.enable_summary_history(trainer)
    summary_plugin.enable_echo_summary_scalar(trainer)
    summary_plugin.set_error_summary_key(trainer, 'error')

    from tartist.plugins.trainer_enhancer import progress as progress_plugin
    progress_plugin.enable_epoch_progress(trainer)

    from tartist.plugins.trainer_enhancer import snapshot as snapshot_plugin
    snapshot_plugin.enable_snapshot_saver(trainer)

    from tartist.plugins.trainer_enhancer import inference as inference_plugin
    inference_plugin.enable_inference_runner(trainer, make_dataflow_inference)

    from tartist.core import register_event

    def _decay_lr_at_epoch_five(trainer):
        # One-off step decay: only the epoch == 5 event changes the rate.
        if trainer.epoch != 5:
            return
        optimizer = trainer.optimizer
        optimizer.set_learning_rate(optimizer.learning_rate * 0.1)

    register_event(trainer, 'epoch:after', _decay_lr_at_epoch_five)

    trainer.train()
コード例 #3
0
def enable_param_clippping(trainer):
    """Clip the discriminator's trainable variables after every iteration.

    Inside ``trainer.env``, one assign op is built for each trainable variable
    whose name starts with ``GANGraphKeys.DISCRIMINATOR_VARIABLES + '/'``.
    Each op clamps its variable to ``[-limit, limit]``. ``limit`` comes from
    the env key ``'trainer.clip_limit'`` and defaults to ``0.01``.

    The assign ops are grouped into a single op. That op is run in
    ``trainer.env.session`` by a handler registered on the ``'iter:after'``
    event.

    NOTE(review): the name has a triple "p"; it is kept unchanged because
    callers import it under this spelling.

    Args:
        trainer: trainer whose ``env`` provides the graph (via
            ``as_default()``) and the session the clip op runs in.
    """
    limit = get_env('trainer.clip_limit', 0.01)
    scope_prefix = GANGraphKeys.DISCRIMINATOR_VARIABLES + '/'

    with trainer.env.as_default():
        clip_ops = [
            v.assign(tf.clip_by_value(v, -limit, limit))
            for v in tf.trainable_variables()
            if v.name.startswith(scope_prefix)
        ]
        # Group inside the env scope as well. With no matching variables,
        # tf.group() has no inputs to infer a graph from, so it creates a
        # no-op in the current default graph. Outside this scope, that graph
        # may not be the env's graph, and session.run() would then reject it.
        clip_op = tf.group(*clip_ops)

    def do_clip_params(trainer, inp, out):
        # Event-handler signature; only the trainer's session is needed here.
        trainer.env.session.run(clip_op)

    from tartist.core import register_event
    register_event(trainer, 'iter:after', do_clip_params)