def test_args_variational_options(self):
    """Check that ``VariationalArgs.compose`` emits each configured option.

    Renamed from ``test_args_variational``: another method in this class is
    also named ``test_args_variational``, and a later ``def`` in a class body
    rebinds the name, so under the old name this test was never run.
    """

    def compose_line(**kwargs):
        """Validate ``VariationalArgs(**kwargs)`` and return its command as one string."""
        variational = VariationalArgs(**kwargs)
        variational.validate(chains=1)
        return ' '.join(variational.compose(idx=0, cmd=[]))

    # Constructing with all defaults must not raise.
    self.assertIsInstance(VariationalArgs(), VariationalArgs)

    cmd = compose_line(output_samples=1)
    self.assertIn('method=variational', cmd)
    self.assertIn('output_samples=1', cmd)

    cmd = compose_line(tol_rel_obj=0.01)
    self.assertIn('method=variational', cmd)
    self.assertIn('tol_rel_obj=0.01', cmd)

    # The adapt sub-arguments are expected to appear together, in this order.
    cmd = compose_line(adapt_engaged=True, adapt_iter=100)
    self.assertIn('adapt engaged=1 iter=100', cmd)

    cmd = compose_line(adapt_engaged=False)
    self.assertIn('adapt engaged=0', cmd)

    cmd = compose_line(eta=0.1)
    self.assertIn('eta=0.1', cmd)
def test_args_variational(self):
    """Smoke-test construction and command composition of ``VariationalArgs``.

    Default construction must not raise. Each option case is validated for a
    single chain, and the composed command must name the variational method
    and carry the option.
    """
    VariationalArgs()  # default construction must succeed

    option_cases = (
        ({'output_samples': 1}, 'output_samples=1'),
        ({'tol_rel_obj': 1}, 'tol_rel_obj=1'),
    )
    for option_kwargs, expected_fragment in option_cases:
        variational = VariationalArgs(**option_kwargs)
        variational.validate(chains=1)
        joined = ' '.join(variational.compose(idx=0, cmd=[]))
        self.assertIn('method=variational', joined)
        self.assertIn(expected_fragment, joined)
def test_args_bad(self):
    """Each setting below is expected to make ``validate()`` raise ``ValueError``.

    Construction happens outside the ``assertRaises`` context, so only the
    ``validate()`` call is expected to raise.
    """
    invalid_settings = [
        {'algorithm': 'no_such_algo'},
        {'iter': 0},
        {'iter': 1.1},
        {'grad_samples': 0},
        {'grad_samples': 1.1},
        {'elbo_samples': 0},
        {'elbo_samples': 1.1},
        {'eta': -0.00003},
        {'adapt_iter': 0},
        {'adapt_iter': 1.1},
        {'tol_rel_obj': 0},
        {'eval_elbo': 0},
        {'eval_elbo': 1.5},
        {'output_samples': 0},
    ]
    for settings in invalid_settings:
        candidate = VariationalArgs(**settings)
        with self.assertRaises(ValueError):
            candidate.validate()