def test_good(self):
    """Every explicitly set sampler option appears in the composed command.

    A second scenario sets only the warmup count and checks that the
    options left unset are not emitted.
    """
    # Scenario 1: many options set -- each one must be rendered.
    full_args = SamplerArgs(
        warmup_iters=10,
        sampling_iters=20,
        save_warmup=True,
        thin=7,
        max_treedepth=15,
        adapt_delta=0.99,
    )
    full_args.validate(chains=4)
    line = ' '.join(full_args.compose(1, cmd=[]))
    for expected in (
        'method=sample',
        'num_warmup=10',
        'num_samples=20',
        'save_warmup=1',
        'thin=7',
        'algorithm=hmc engine=nuts',
        'max_depth=15 adapt delta=0.99',
    ):
        self.assertIn(expected, line)

    # Scenario 2: only warmup set -- unrelated options must be absent.
    warmup_only = SamplerArgs(warmup_iters=10)
    warmup_only.validate(chains=4)
    line = ' '.join(warmup_only.compose(1, cmd=[]))
    for expected in ('method=sample', 'num_warmup=10'):
        self.assertIn(expected, line)
    for absent in ('num_samples=', 'save_warmup=', 'algorithm=hmc engine=nuts'):
        self.assertNotIn(absent, line)
def test_adapt(self):
    """Adaptation settings are rendered under the ``adapt`` sub-argument.

    NOTE(review): a second ``test_adapt`` appears in this chunk; if both are
    methods of the same TestCase, the later one replaces this one and this
    test never runs -- confirm.
    """
    def render(sampler_args):
        # Validate for four chains, then flatten chain 1's command to a string.
        sampler_args.validate(chains=4)
        return ' '.join(sampler_args.compose(1, cmd=[]))

    self.assertIn(
        'method=sample algorithm=hmc adapt engaged=0',
        render(SamplerArgs(adapt_engaged=False)),
    )
    self.assertIn(
        'method=sample algorithm=hmc adapt engaged=1',
        render(SamplerArgs(adapt_engaged=True)),
    )

    # Windowing parameters map onto init_buffer / window / term_buffer.
    line = render(
        SamplerArgs(adapt_init_phase=26, adapt_metric_window=60, adapt_step_size=34)
    )
    for expected in (
        'method=sample algorithm=hmc adapt',
        'init_buffer=26',
        'window=60',
        'term_buffer=34',
    ):
        self.assertIn(expected, line)

    # Defaults: neither the engine nor a disabled-adaptation flag is emitted.
    line = render(SamplerArgs())
    for absent in ('engine=nuts', 'adapt engaged=0'):
        self.assertNotIn(absent, line)
def test_adapt(self):
    """The ``adapt engaged`` flag follows ``adapt_engaged``; defaults omit it.

    NOTE(review): this version passes ``''`` as the command and checks the
    result of ``compose`` directly as a string, unlike the list-based
    ``compose(..., cmd=[])`` calls elsewhere in this chunk; it also reuses
    the name ``test_adapt``. Confirm which ``compose`` API is current and
    whether both definitions share a TestCase (only the last would run).
    """
    for engaged, expected in (
        (False, 'method=sample algorithm=hmc adapt engaged=0'),
        (True, 'method=sample algorithm=hmc adapt engaged=1'),
    ):
        sampler_args = SamplerArgs(adapt_engaged=engaged)
        sampler_args.validate(chains=4)
        self.assertIn(expected, sampler_args.compose(1, ''))

    defaults = SamplerArgs()
    defaults.validate(chains=4)
    rendered = defaults.compose(1, '')
    for absent in ('engine=nuts', 'engaged=1'):
        self.assertNotIn(absent, rendered)
def test_bad(self):
    """Invalid sampler settings make ``validate`` raise ``ValueError``.

    NOTE(review): ``test_bad`` is defined more than once in this chunk; if
    the definitions share a TestCase, only the last one runs -- confirm.
    """
    def assert_rejected(chains=2, **settings):
        # Construct before entering assertRaises so that a constructor error
        # is not mistaken for the expected validate() failure.
        sampler_args = SamplerArgs(**settings)
        with self.assertRaises(ValueError):
            sampler_args.validate(chains=chains)

    assert_rejected(warmup_iters=-10)
    assert_rejected(warmup_iters=0, adapt_engaged=True)
    assert_rejected(sampling_iters=-10)
    assert_rejected(thin=-10)
    assert_rejected(max_treedepth=-10)
    assert_rejected(step_size=-10)
    # step_size list whose length differs from the chain count
    assert_rejected(chains=1, step_size=[1.0, 1.1])
    assert_rejected(step_size=[1.0, -1.1])
    assert_rejected(adapt_delta=1.1)
    assert_rejected(adapt_delta=-0.1)

    # fixed_param=True must not be combined with any of these options.
    assert_rejected(warmup_iters=100, fixed_param=True)
    assert_rejected(save_warmup=True, fixed_param=True)
    assert_rejected(max_treedepth=12, fixed_param=True)
    assert_rejected(metric='dense', fixed_param=True)
    assert_rejected(step_size=0.5, fixed_param=True)
    assert_rejected(adapt_delta=0.88, fixed_param=True)
def test_args_chains(self):
    """``validate`` refuses to run without a chain count."""
    sampler_args = SamplerArgs()
    with self.assertRaises(ValueError):
        sampler_args.validate(chains=None)
def test_args_min(self):
    """Default sampler arguments still compose a ``sample``/``hmc`` command."""
    defaults = SamplerArgs()
    defaults.validate(chains=4)
    rendered = ' '.join(defaults.compose(idx=1, cmd=[]))
    self.assertIn('method=sample algorithm=hmc', rendered)
def test_fixed_param(self):
    """``fixed_param=True`` selects the ``fixed_param`` sampling algorithm."""
    fixed = SamplerArgs(fixed_param=True)
    fixed.validate(chains=1)
    rendered = ' '.join(fixed.compose(0, cmd=[]))
    self.assertIn('method=sample algorithm=fixed_param', rendered)
def test_metric(self):
    """Metric shorthands, metric files and invalid metrics are handled.

    Covers: shorthand names rendered as the ``_e`` metric, no ``metric=``
    by default, a single JSON metric file, one metric file per chain, and
    configurations that ``validate`` must reject.
    """
    # Shorthand and full metric names, with the text each must render.
    for given, expected in (
        ('dense_e', 'method=sample algorithm=hmc metric=dense_e'),
        ('dense', 'method=sample algorithm=hmc metric=dense_e'),
        ('diag_e', 'method=sample algorithm=hmc metric=diag_e'),
        ('diag', 'method=sample algorithm=hmc metric=diag_e'),
    ):
        sampler_args = SamplerArgs(metric=given)
        sampler_args.validate(chains=4)
        self.assertIn(expected, ' '.join(sampler_args.compose(1, cmd=[])))

    # No metric requested: nothing metric-related is emitted.
    defaults = SamplerArgs()
    defaults.validate(chains=4)
    self.assertNotIn('metric=', ' '.join(defaults.compose(1, cmd=[])))

    # A single metric file is passed through as metric_file=.
    metric_file = os.path.join(datafiles_path, 'bernoulli.metric.json')
    from_file = SamplerArgs(metric=metric_file)
    from_file.validate(chains=4)
    rendered = ' '.join(from_file.compose(1, cmd=[]))
    for fragment in ('metric=diag_e', 'metric_file=', 'bernoulli.metric.json'):
        self.assertIn(fragment, rendered)

    # One metric file per chain: each chain's command names its own file.
    metric_file_2 = os.path.join(datafiles_path, 'bernoulli.metric-2.json')
    per_chain = SamplerArgs(metric=[metric_file, metric_file_2])
    per_chain.validate(chains=2)
    self.assertIn('bernoulli.metric.json', ' '.join(per_chain.compose(0, cmd=[])))
    self.assertIn('bernoulli.metric-2.json', ' '.join(per_chain.compose(1, cmd=[])))

    # Configurations the test expects validate() to reject.
    for metric, n_chains in (
        ([metric_file, metric_file], 2),
        ([metric_file, metric_file_2], 4),
        ('/no/such/path/to.file', 4),
    ):
        rejected = SamplerArgs(metric=metric)
        with self.assertRaises(ValueError):
            rejected.validate(chains=n_chains)
def test_bad(self):
    """Invalid sampler settings make ``validate`` raise ``ValueError``.

    This version uses the ``iter_warmup`` / ``iter_sampling`` keyword names
    and also covers the adaptation windowing parameters.

    NOTE(review): ``test_bad`` is defined more than once in this chunk, with
    different keyword names; if the definitions share a TestCase, only the
    last one runs -- confirm.
    """
    def assert_rejected(chains=2, **settings):
        # Construct before entering assertRaises so that a constructor error
        # is not mistaken for the expected validate() failure.
        sampler_args = SamplerArgs(**settings)
        with self.assertRaises(ValueError):
            sampler_args.validate(chains=chains)

    assert_rejected(iter_warmup=-10)
    assert_rejected(iter_warmup=10, adapt_engaged=False)
    assert_rejected(iter_sampling=-10)
    assert_rejected(thin=-10)
    assert_rejected(max_treedepth=-10)
    assert_rejected(step_size=-10)
    # step_size list whose length differs from the chain count
    assert_rejected(chains=1, step_size=[1.0, 1.1])
    assert_rejected(step_size=[1.0, -1.1])
    assert_rejected(adapt_delta=1.1)
    assert_rejected(adapt_delta=-0.1)

    # fixed_param=True must not be combined with these options.
    assert_rejected(iter_warmup=100, fixed_param=True)
    assert_rejected(save_warmup=True, fixed_param=True)
    assert_rejected(max_treedepth=12, fixed_param=True)
    assert_rejected(metric='dense', fixed_param=True)
    assert_rejected(step_size=0.5, fixed_param=True)

    assert_rejected(adapt_delta=0.88, adapt_engaged=False)

    # Adaptation windowing parameters: non-integer values ...
    assert_rejected(adapt_init_phase=0.88)
    assert_rejected(adapt_metric_window=0.88)
    assert_rejected(adapt_step_size=0.88)
    # ... and negative values are both rejected.
    assert_rejected(adapt_init_phase=-1)
    assert_rejected(adapt_metric_window=-2)
    assert_rejected(adapt_step_size=-3)

    assert_rejected(adapt_delta=0.88, fixed_param=True)
def test_bad(self):
    """Out-of-range sampler settings make ``validate`` raise ``ValueError``.

    NOTE(review): ``test_bad`` is defined more than once in this chunk; if
    the definitions share a TestCase, only the last one runs -- confirm.
    """
    def assert_rejected(chains=2, **settings):
        # Construct before entering assertRaises so that a constructor error
        # is not mistaken for the expected validate() failure.
        sampler_args = SamplerArgs(**settings)
        with self.assertRaises(ValueError):
            sampler_args.validate(chains=chains)

    assert_rejected(warmup_iters=-10)
    assert_rejected(warmup_iters=0, adapt_engaged=True)
    assert_rejected(sampling_iters=-10)
    assert_rejected(thin=-10)
    assert_rejected(max_treedepth=-10)
    assert_rejected(step_size=-10)
    # step_size list whose length differs from the chain count
    assert_rejected(chains=1, step_size=[1.0, 1.1])
    assert_rejected(step_size=[1.0, -1.1])
    assert_rejected(adapt_delta=1.1)
    assert_rejected(adapt_delta=-0.1)