def test_good(self):
        """Non-default sampler arguments are rendered into the command.

        First checks a fully specified SamplerArgs, then one that sets only
        ``warmup_iters`` and therefore must not emit the other options.
        """
        full_args = SamplerArgs(
            warmup_iters=10,
            sampling_iters=20,
            save_warmup=True,
            thin=7,
            max_treedepth=15,
            adapt_delta=0.99,
        )
        full_args.validate(chains=4)
        # compose() returns a token sequence; join once and check substrings.
        rendered = ' '.join(full_args.compose(1, cmd=[]))
        for expected in (
            'method=sample',
            'num_warmup=10',
            'num_samples=20',
            'save_warmup=1',
            'thin=7',
            'algorithm=hmc engine=nuts',
            'max_depth=15 adapt delta=0.99',
        ):
            self.assertIn(expected, rendered)

        warmup_only = SamplerArgs(warmup_iters=10)
        warmup_only.validate(chains=4)
        rendered = ' '.join(warmup_only.compose(1, cmd=[]))
        self.assertIn('method=sample', rendered)
        self.assertIn('num_warmup=10', rendered)
        for absent in (
            'num_samples=',
            'save_warmup=',
            'algorithm=hmc engine=nuts',
        ):
            self.assertNotIn(absent, rendered)
Example #2
0
    def test_adapt(self):
        """Adaptation settings appear in the composed sampler command.

        NOTE(review): a second ``test_adapt`` is defined later in this file.
        If both live in the same class, the later definition rebinds the name
        and this one never runs -- consider renaming one of them.
        """
        for engaged, expected in (
            (False, 'method=sample algorithm=hmc adapt engaged=0'),
            (True, 'method=sample algorithm=hmc adapt engaged=1'),
        ):
            sampler = SamplerArgs(adapt_engaged=engaged)
            sampler.validate(chains=4)
            self.assertIn(expected, ' '.join(sampler.compose(1, cmd=[])))

        tuned = SamplerArgs(adapt_init_phase=26,
                            adapt_metric_window=60,
                            adapt_step_size=34)
        tuned.validate(chains=4)
        rendered = ' '.join(tuned.compose(1, cmd=[]))
        self.assertIn('method=sample algorithm=hmc adapt', rendered)
        # The three adapt_* settings are expected to render as
        # init_buffer / window / term_buffer respectively.
        self.assertIn('init_buffer=26', rendered)
        self.assertIn('window=60', rendered)
        self.assertIn('term_buffer=34', rendered)

        defaults = SamplerArgs()
        defaults.validate(chains=4)
        rendered = ' '.join(defaults.compose(1, cmd=[]))
        self.assertNotIn('engine=nuts', rendered)
        self.assertNotIn('adapt engaged=0', rendered)
    def test_adapt(self):
        """``adapt_engaged`` is rendered as ``adapt engaged=0|1``.

        NOTE(review): an earlier ``test_adapt`` with overlapping checks exists
        in this file. If both live in the same class, this later definition
        rebinds the name and the earlier one never runs -- consider merging
        them or renaming one.
        """
        # The other tests in this file pass ``cmd=[]`` and join compose()'s
        # returned tokens with spaces before checking substrings. The old
        # ``compose(1, '')`` form does not match that API: membership on the
        # returned token list would compare whole elements, not substrings.
        args = SamplerArgs(adapt_engaged=False)
        args.validate(chains=4)
        cmd = ' '.join(args.compose(1, cmd=[]))
        self.assertIn('method=sample algorithm=hmc adapt engaged=0', cmd)

        args = SamplerArgs(adapt_engaged=True)
        args.validate(chains=4)
        cmd = ' '.join(args.compose(1, cmd=[]))
        self.assertIn('method=sample algorithm=hmc adapt engaged=1', cmd)

        args = SamplerArgs()
        args.validate(chains=4)
        cmd = ' '.join(args.compose(1, cmd=[]))
        self.assertNotIn('engine=nuts', cmd)
        self.assertNotIn('engaged=1', cmd)
    def test_metric(self):
        """``metric`` accepts names, aliases and JSON metric files.

        Covers: ``dense``/``diag`` rendering as the ``*_e`` forms, no
        ``metric=`` token when unset, a single metric file, one file per
        chain, and ValueError from ``validate`` for a list naming the same
        file twice, a list whose length differs from ``chains``, and a
        nonexistent path.
        """
        def rendered(sampler, idx=1):
            # compose() yields tokens; tests check substrings of the joined form.
            return ' '.join(sampler.compose(idx, cmd=[]))

        for requested, expected in (
            ('dense_e', 'method=sample algorithm=hmc metric=dense_e'),
            ('dense', 'method=sample algorithm=hmc metric=dense_e'),
            ('diag_e', 'method=sample algorithm=hmc metric=diag_e'),
            ('diag', 'method=sample algorithm=hmc metric=diag_e'),
        ):
            sampler = SamplerArgs(metric=requested)
            sampler.validate(chains=4)
            self.assertIn(expected, rendered(sampler))

        sampler = SamplerArgs()
        sampler.validate(chains=4)
        self.assertNotIn('metric=', rendered(sampler))

        jmetric = os.path.join(datafiles_path, 'bernoulli.metric.json')
        sampler = SamplerArgs(metric=jmetric)
        sampler.validate(chains=4)
        text = rendered(sampler)
        self.assertIn('metric=diag_e', text)
        self.assertIn('metric_file=', text)
        self.assertIn('bernoulli.metric.json', text)

        # One metric file per chain: chain index selects the file.
        jmetric2 = os.path.join(datafiles_path, 'bernoulli.metric-2.json')
        per_chain = SamplerArgs(metric=[jmetric, jmetric2])
        per_chain.validate(chains=2)
        self.assertIn('bernoulli.metric.json', rendered(per_chain, idx=0))
        self.assertIn('bernoulli.metric-2.json', rendered(per_chain, idx=1))

        for bad_metric, chains in (
            ([jmetric, jmetric], 2),
            ([jmetric, jmetric2], 4),
            ('/no/such/path/to.file', 4),
        ):
            bad = SamplerArgs(metric=bad_metric)
            with self.assertRaises(ValueError):
                bad.validate(chains=chains)
 def test_args_min(self):
     """With no options set, the command still selects HMC sampling."""
     sampler = SamplerArgs()
     sampler.validate(chains=4)
     tokens = sampler.compose(idx=1, cmd=[])
     joined = ' '.join(tokens)
     self.assertIn('method=sample algorithm=hmc', joined)
 def test_fixed_param(self):
     """``fixed_param=True`` switches the algorithm to ``fixed_param``."""
     sampler = SamplerArgs(fixed_param=True)
     sampler.validate(chains=1)
     joined = ' '.join(sampler.compose(0, cmd=[]))
     self.assertIn('method=sample algorithm=fixed_param', joined)