Пример #1
0
    def test_validate_good_run(self):
        """Validate a StanFit assembled from existing, well-formed output.

        The fit is pointed at the ``bern`` output basename under
        ``goodfiles_path`` and every chain's return code is set to 0 by
        hand, after which the validation, drawset, summary and diagnose
        entry points are exercised.
        """
        # construct fit using existing sampler output
        model_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
        output_base = os.path.join(goodfiles_path, 'bern')
        fit = StanFit(
            args=CmdStanArgs(
                model_name='bernoulli',
                model_exe=model_exe,
                chain_ids=[1, 2, 3, 4],
                seed=12345,
                data=data_file,
                output_basename=output_base,
                method_args=SamplerArgs(
                    sampling_iters=100,
                    max_treedepth=11,
                    adapt_delta=0.95,
                ),
            ),
            chains=4,
        )

        # Mark every chain as having exited cleanly before validating.
        for chain_idx, _ in enumerate(fit._retcodes):
            fit._set_retcode(chain_idx, 0)
        self.assertTrue(fit._check_retcodes())
        fit._check_console_msgs()
        fit._validate_csv_files()

        self.assertEqual(4, fit.chains)
        self.assertEqual(100, fit.draws)
        self.assertEqual(8, len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        # The drawset holds one row per draw per chain, one column per name.
        drawset = fit.get_drawset()
        actual_shape = drawset.shape
        expected_shape = (fit.chains * fit.draws, len(fit.column_names))
        self.assertEqual(actual_shape, expected_shape)

        # The summary result is not inspected; the call only has to succeed.
        fit.summary()

        # TODO - use cmdstan test files instead
        expected_report = (
            'Checking sampler transitions treedepth.\n'
            'Treedepth satisfactory for all transitions.\n'
            '\n'
            'Checking sampler transitions for divergences.\n'
            'No divergent transitions found.\n'
            '\n'
            'Checking E-BFMI - sampler transitions HMC potential energy.\n'
            'E-BFMI satisfactory for all transitions.\n'
            '\n'
            'Effective sample size satisfactory.'
        )
        # Normalise Windows line endings so the comparison is portable.
        diagnose_text = fit.diagnose().replace("\r\n", "\n")
        self.assertIn(expected_report, diagnose_text)
Пример #2
0
    def test_validate_bad_run(self):
        """Check that malformed sampler output is rejected.

        Each scenario builds a four-chain StanFit over a different output
        basename under ``badfiles_path`` and asserts that the relevant
        validation method raises with the expected message pattern.
        """
        exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        sampler_args = SamplerArgs(sampling_iters=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)

        def make_fit(basename):
            # Build a four-chain fit over badfiles_path/<basename>.
            cmdstan_args = CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=[1, 2, 3, 4],
                seed=12345,
                data=jdata,
                output_basename=os.path.join(badfiles_path, basename),
                method_args=sampler_args,
            )
            return StanFit(args=cmdstan_args, chains=4)

        def make_fit_with_zero_retcodes(basename):
            # Same as make_fit, but with every chain's return code set to 0
            # so that only the CSV contents can trigger a failure.
            fit = make_fit(basename)
            for chain_idx, _ in enumerate(fit._retcodes):
                fit._set_retcode(chain_idx, 0)
            self.assertTrue(fit._check_retcodes())
            return fit

        # some chains had errors
        fit = make_fit('bad-transcript-bern')
        with self.assertRaisesRegex(Exception, 'Exception'):
            fit._check_console_msgs()

        csv_cases = (
            # csv file headers inconsistent
            ('bad-hdr-bern', 'header mismatch'),
            # bad draws
            ('bad-draws-bern', 'draws'),
            # mismatch - column headers, draws
            ('bad-cols-bern', 'bad draw'),
        )
        for basename, pattern in csv_cases:
            fit = make_fit_with_zero_retcodes(basename)
            # msg identifies the offending fixture if the assertion fails.
            with self.assertRaisesRegex(ValueError, pattern, msg=basename):
                fit._validate_csv_files()