def main(_):
    """Load the model selected by ``FLAGS`` and run the selected method.

    Reads the model, dataset, method and hyperparameters from ``FLAGS``.
    It validates the model and method names, then makes sure the directory
    ``<FLAGS.results_dir>/num_leapfrog_steps=<FLAGS.num_leapfrog_steps>``
    exists. Finally it dispatches to ``run_baseline`` (method ``'baseline'``)
    or ``run_vip`` (method ``'vip'``).

    Args:
      _: Unused positional argument (presumably the argv list supplied by the
        app runner -- confirm against the caller).

    Raises:
      ValueError: If ``FLAGS.model`` or ``FLAGS.method`` is not recognised.
        ``ValueError`` subclasses ``Exception``, so handlers written for the
        previously raised bare ``Exception`` still catch it.
    """
    util.print('Loading model {} with dataset {}.'.format(
        FLAGS.model, FLAGS.dataset))

    # Zero-argument loaders keyed by model name. Lambdas keep the
    # ``models.get_*`` attribute lookups lazy, so only the selected loader is
    # ever resolved (as with the original if/elif chain). Only the radon
    # models take the dataset, as ``state_code``.
    model_loaders = {
        'radon': lambda: models.get_radon(state_code=FLAGS.dataset),
        'radon_stddvs': lambda: models.get_radon_model_stddvs(
            state_code=FLAGS.dataset),
        '8schools': lambda: models.get_eight_schools(),
        'german_credit_gammascale':
            lambda: models.get_german_credit_gammascale(),
        'german_credit_lognormalcentered':
            lambda: models.get_german_credit_lognormalcentered(),
    }
    if FLAGS.model not in model_loaders:
        raise ValueError('unknown model {}; expected one of {}'.format(
            FLAGS.model, sorted(model_loaders)))
    model_config = model_loaders[FLAGS.model]()

    # Validate the method before touching the filesystem, so a typo does not
    # leave behind a freshly created (empty) experiments directory.
    valid_methods = ('baseline', 'vip')
    if FLAGS.method not in valid_methods:
        raise ValueError('No such method: {!r}; expected one of {}'.format(
            FLAGS.method, valid_methods))

    description = FLAGS.model + '_{}'.format(FLAGS.dataset)

    experiments_dir = os.path.join(
        FLAGS.results_dir,
        'num_leapfrog_steps={}'.format(FLAGS.num_leapfrog_steps))
    if not tf.gfile.Exists(experiments_dir):
        tf.gfile.MakeDirs(experiments_dir)

    # Keyword arguments shared by both runners.
    common_kwargs = dict(
        model_config=model_config,
        experiments_dir=experiments_dir,
        num_samples=FLAGS.num_samples,
        burnin=FLAGS.burnin,
        num_adaptation_steps=FLAGS.num_adaptation_steps,
        num_optimization_steps=FLAGS.num_optimization_steps,
        tau=FLAGS.tau,
        num_leapfrog_steps=FLAGS.num_leapfrog_steps,
        description=description)

    # NOTE(review): ``description`` is passed both positionally and as the
    # ``description=`` keyword. If the runner's first parameter is
    # ``description`` (or any other name also passed by keyword here, e.g.
    # ``model_config``), Python raises TypeError ("got multiple values for
    # argument"). The call shape is kept because the runner signatures are
    # not visible here -- confirm and drop one of the two.
    if FLAGS.method == 'baseline':
        run_baseline(description, **common_kwargs)
    else:  # 'vip' -- guaranteed by the validation above.
        run_vip(description,
                use_iaf_posterior=FLAGS.use_iaf_posterior,
                num_mc_samples=FLAGS.num_mc_samples,
                **common_kwargs)
# Example #2
# 0
 def test_eight_schools(self):
     """Run `_sanity_check_conversion` on the eight-schools model components.

     `models.get_eight_schools()` is expected to return exactly six items
     (model, its args, observations, the to-CP and to-NCP transforms, and
     the CP-transform factory). These are forwarded unchanged, in order.
     """
     (eight_schools_model, eight_schools_args, eight_schools_obs,
      cp_transform, ncp_transform,
      cp_transform_factory) = models.get_eight_schools()
     self._sanity_check_conversion(
         eight_schools_model,
         eight_schools_args,
         eight_schools_obs,
         cp_transform,
         ncp_transform,
         cp_transform_factory,
     )