示例#1
0
 def test_variational_eta_fail(self):
     """Meanfield ADVI on ``eta_should_fail.stan`` must raise.

     Runs ``variational`` with a fixed seed and expects a ``RuntimeError``
     whose message matches ``'algorithm may not have converged'``.
     """
     expected_msg = 'algorithm may not have converged'
     stan_path = os.path.join(
         DATAFILES_PATH, 'variational', 'eta_should_fail.stan'
     )
     failing_model = CmdStanModel(stan_file=stan_path)
     with self.assertRaisesRegex(RuntimeError, expected_msg):
         failing_model.variational(algorithm='meanfield', seed=12345)
    def test_variational_good(self):
        """Check meanfield ADVI results for ``eta_should_be_big.stan``.

        With a fixed seed, asserts the output column names, the values of
        ``mu.1``/``mu.2`` (to 2 decimal places) in both the numpy and dict
        views of the variational parameters, that the numpy view agrees with
        row 0 of the pandas view, and the shape of the variational sample.
        """
        # Use DATAFILES_PATH: the other tests here locate the test-data
        # directory through that name, not a lowercase ``datafiles_path``.
        stan = os.path.join(DATAFILES_PATH, 'variational',
                            'eta_should_be_big.stan')
        model = CmdStanModel(stan_file=stan)
        vi = model.variational(algorithm='meanfield', seed=12345)
        self.assertEqual(vi.column_names,
                         ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'))

        # Expected values for this model and seed, shared by the numpy and
        # dict checks so the two cannot drift apart.
        expected_mu = (('mu.1', 31.0418), ('mu.2', 27.4463))
        for name, expected in expected_mu:
            # Derive the position from column_names (asserted above) rather
            # than hard-coding indices 3 and 4.
            idx = vi.column_names.index(name)
            self.assertAlmostEqual(vi.variational_params_np[idx], expected,
                                   places=2)
            self.assertAlmostEqual(vi.variational_params_dict[name], expected,
                                   places=2)

        # The numpy vector must hold the same values as row 0 of the pandas
        # view for each checked column.
        for name in ('lp__', 'mu.1', 'mu.2'):
            idx = vi.column_names.index(name)
            self.assertEqual(vi.variational_params_np[idx],
                             vi.variational_params_pd[name][0])

        # 1000 rows by 5 columns (one column per entry in column_names).
        self.assertEqual(vi.variational_sample.shape, (1000, 5))
 def test_variational_eta_small(self):
     """Check meanfield ADVI results for ``eta_should_be_small.stan``.

     With a fixed seed, asserts the output column names and that the
     magnitudes of ``mu.1`` and ``mu.2`` are approximately 0.08 and 0.09
     (to 1 decimal place).
     """
     # Use DATAFILES_PATH: the other tests here locate the test-data
     # directory through that name, not a lowercase ``datafiles_path``.
     stan = os.path.join(DATAFILES_PATH, 'variational',
                         'eta_should_be_small.stan')
     model = CmdStanModel(stan_file=stan)
     vi = model.variational(algorithm='meanfield', seed=12345)
     self.assertEqual(vi.column_names,
                      ('lp__', 'log_p__', 'log_g__', 'mu.1', 'mu.2'))
     # Only absolute values are compared, so the sign of each estimate is
     # deliberately not checked.
     for name, expected in (('mu.1', 0.08), ('mu.2', 0.09)):
         self.assertAlmostEqual(fabs(vi.variational_params_dict[name]),
                                expected,
                                places=1)
示例#4
0
 def test_variational_eta_small(self):
     """Check meanfield ADVI results for ``eta_should_be_small.stan``.

     With a fixed seed, asserts the output column names and that the
     magnitudes of ``mu[1]`` and ``mu[2]`` are approximately 0.08 and 0.09
     (to 1 decimal place).
     """
     stan = os.path.join(DATAFILES_PATH, 'variational',
                         'eta_should_be_small.stan')
     model = CmdStanModel(stan_file=stan)
     variational = model.variational(algorithm='meanfield', seed=12345)
     self.assertEqual(
         variational.column_names,
         ('lp__', 'log_p__', 'log_g__', 'mu[1]', 'mu[2]'),
     )
     params = variational.variational_params_dict
     # Only absolute values are compared, so the sign of each estimate is
     # deliberately not checked.
     self.assertAlmostEqual(fabs(params['mu[1]']), 0.08, places=1)
     self.assertAlmostEqual(fabs(params['mu[2]']), 0.09, places=1)
示例#5
0
    cmdstanpy.install_cmdstan()

    # Instantiate (and compile) the Stan model named on the command line.
    model_path = Path(args.models_path) / ('%s.stan' % args.model_name)
    sicr_model = CmdStanModel(stan_file=model_path)

    # Output tree: <fits_path>/std_out receives the files CmdStan produces.
    # parents=True on std_out also creates fits_path itself if missing.
    save_dir = Path(args.fits_path)
    output_dir = save_dir / 'std_out'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Run CmdStan's variational inference method; returns a `CmdStanVB`.
    # grad_samples / elbo_samples are Monte Carlo draw counts for the
    # gradient and ELBO estimates; output_samples is the number of
    # approximate-posterior draws written out.
    sicr_model_vb = sicr_model.variational(data=stan_data,
                                           algorithm=args.advi_algorithm,
                                           grad_samples=4000,
                                           elbo_samples=4000,
                                           output_samples=4000,
                                           eta=args.advi_eta,
                                           adapt_iter=args.advi_adapt_iter,
                                           output_dir=output_dir)

    vb_results = sicr_model_vb.variational_params_dict  # only gives means
    # One row per parameter name (dict keys become the index).
    vb_df = pd.DataFrame.from_dict(vb_results, orient="index")
    print(vb_df)
    # save_path = save_dir / ("%s_%s_ADVI_means.csv" % (args.model_name, args.roi))
    # vb_df.to_csv(save_path)