예제 #1
0
    def test_args_good(self):
        """Composed sample commands carry chain id, seed, data and output."""
        model_exe = os.path.join(datafiles_path, 'bernoulli')
        data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
        method_args = SamplerArgs()

        # contiguous chain ids starting at 1
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            data=data_file,
            method_args=method_args,
        )
        self.assertEqual(args.method, Method.SAMPLE)
        cmd_line = ' '.join(
            args.compose_command(idx=0, csv_file='bern-output-1.csv'))
        for fragment in (
            'id=1 random seed=',
            'data file=',
            'output file=',
            'method=sample algorithm=hmc',
        ):
            self.assertIn(fragment, cmd_line)

        # non-contiguous chain ids: index 0 must map to the first id
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[7, 11, 18, 29],
            data=data_file,
            method_args=method_args,
        )
        cmd_line = ' '.join(
            args.compose_command(idx=0, csv_file='bern-output-1.csv'))
        self.assertIn('id=7 random seed=', cmd_line)
예제 #2
0
    def test_args_inits(self):
        """The ``inits`` argument is rendered as ``init=...`` per chain."""
        model_exe = os.path.join(datafiles_path, 'bernoulli')
        data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
        method_args = SamplerArgs()

        init_file = os.path.join(datafiles_path, 'bernoulli.init.json')
        init_file_1 = os.path.join(datafiles_path, 'bernoulli.init_1.json')
        init_file_2 = os.path.join(datafiles_path, 'bernoulli.init_2.json')

        def make_args(chain_ids, inits):
            # every case shares model, data and sampler settings
            return CmdStanArgs(
                model_name='bernoulli',
                model_exe=model_exe,
                chain_ids=chain_ids,
                data=data_file,
                inits=inits,
                method_args=method_args,
            )

        def command_line(args, idx=0):
            return ' '.join(
                args.compose_command(idx=idx, csv_file='bern-output-1.csv'))

        # one init file shared by all chains
        args = make_args([1, 2, 3, 4], init_file)
        self.assertIn('init=', command_line(args))

        # one init file per chain, selected by chain index
        args = make_args([1, 2], [init_file_1, init_file_2])
        self.assertIn('bernoulli.init_1.json', command_line(args, 0))
        self.assertIn('bernoulli.init_2.json', command_line(args, 1))

        # numeric inits appear verbatim as init=<value>
        for value, expected in ((0, 'init=0'), (3.33, 'init=3.33')):
            args = make_args([1, 2, 3, 4], value)
            self.assertIn(expected, command_line(args))
예제 #3
0
 def test_compose(self):
     """compose_command raises ValueError for out-of-range chain indices."""
     cmdstan_args = CmdStanArgs(
         model_name='bernoulli',
         model_exe=os.path.join(datafiles_path, 'bernoulli'),
         chain_ids=[1, 2, 3, 4],
         method_args=SamplerArgs(),
     )
     # one past the last of the four chains, and a negative index
     for bad_idx in (4, -1):
         with self.assertRaises(ValueError):
             cmdstan_args.compose_command(idx=bad_idx, csv_file='foo')
예제 #4
0
 def __init__(self,
              args: CmdStanArgs,
              chains: int = 4,
              logger: logging.Logger = None) -> None:
     """Initialize object.

     For each of ``chains`` chains this records the output CSV path, the
     console ``.txt`` path derived from it, and the command line returned
     by ``args.compose_command``.

     :param args: CmdStan arguments used to compose each chain's command.
     :param chains: number of chains; must be at least 1.
     :param logger: logger to use; falls back to ``get_logger()``.
     :raises ValueError: if ``chains`` is less than 1.
     """
     self._args = args
     self._chains = chains
     self._logger = logger or get_logger()
     if chains < 1:
         raise ValueError('chains must be positive integer value, '
                          'found {}'.format(chains))
     if args.output_basename is None:
         # no basename supplied: let create_named_text_file pick the
         # per-chain file names inside TMPDIR
         csv_basename = 'stan-{}-{}'.format(args.model_name, args.method)
         self._csv_files = [
             create_named_text_file(
                 dir=TMPDIR,
                 prefix='{}-{}-'.format(csv_basename, i + 1),
                 suffix='.csv',
             )
             for i in range(chains)
         ]
     else:
         self._csv_files = [
             '{}-{}.csv'.format(args.output_basename, i + 1)
             for i in range(chains)
         ]
     # console output for a chain lives next to its CSV with a .txt suffix
     self._console_files = [
         os.path.splitext(csv_file)[0] + '.txt'
         for csv_file in self._csv_files
     ]
     self._cmds = [
         args.compose_command(i, csv_file)
         for i, csv_file in enumerate(self._csv_files)
     ]
     self._retcodes = [-1 for _ in range(chains)]
예제 #5
0
    def test_no_chains(self):
        """Check CmdStanArgs when ``chain_ids`` is None.

        A single init file still appears on the composed command line, while
        a list of seeds or a list of init files is rejected with ValueError.
        """
        exe = os.path.join(datafiles_path, 'bernoulli')
        method_args = FixedParamArgs()
        jinits = os.path.join(datafiles_path, 'bernoulli.init.json')
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=None,
            inits=jinits,
            method_args=method_args,
        )
        # compose_command yields a list of argument tokens; join them so
        # assertIn performs a substring check instead of requiring a token
        # exactly equal to 'init='.
        cmd = cmdstan_args.compose_command(None, 'out.csv')
        self.assertIn('init=', ' '.join(cmd))

        # per-chain seeds make no sense without chain ids
        with self.assertRaises(ValueError):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=None,
                seed=[1, 2, 3],
                inits=jinits,
                method_args=method_args,
            )

        # neither do per-chain init files
        with self.assertRaises(ValueError):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=None,
                inits=[jinits],
                method_args=method_args,
            )
예제 #6
0
    def test_args_good(self):
        """Check composed sample commands and creation of ``output_dir``."""
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli')
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        method_args = SamplerArgs()

        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            data=data_file,
            method_args=method_args,
            refresh=10,
        )
        self.assertEqual(args.method, Method.SAMPLE)
        cmd_line = ' '.join(
            args.compose_command(idx=0, csv_file='bern-output-1.csv'))
        for fragment in (
            'id=1 random seed=',
            'data file=',
            'output file=',
            'method=sample algorithm=hmc',
            'refresh=10',
        ):
            self.assertIn(fragment, cmd_line)

        # index 0 maps to the first of a non-contiguous set of chain ids
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[7, 11, 18, 29],
            data=data_file,
            method_args=method_args,
        )
        cmd_line = ' '.join(
            args.compose_command(idx=0, csv_file='bern-output-1.csv'))
        self.assertIn('id=7 random seed=', cmd_line)

        # constructing CmdStanArgs with a missing output_dir is expected to
        # create that directory; the test removes it afterwards
        out_dir = 'tmp' + str(time())
        if os.path.exists(out_dir):
            os.rmdir(out_dir)
        CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            output_dir=out_dir,
            method_args=method_args,
        )
        self.assertTrue(os.path.exists(out_dir))
        os.rmdir(out_dir)
예제 #7
0
파일: stanfit.py 프로젝트: clonyjr/prophet
 def __init__(self,
              args: CmdStanArgs,
              chains: int = 4,
              logger: logging.Logger = None) -> None:
     """Initialize object.

     For each of ``chains`` chains this records the sample CSV path, the
     console ``.txt`` path derived from it, and the command line returned
     by ``args.compose_command``; all result caches start out as None.

     :param args: CmdStan arguments used to compose each chain's command.
     :param chains: number of chains; must be at least 1.
     :param logger: logger to use; falls back to ``get_logger()``.
     :raises ValueError: if ``chains`` is less than 1.
     """
     self._args = args
     self._is_optimizing = isinstance(self._args.method_args, OptimizeArgs)
     self._chains = chains
     self._logger = logger or get_logger()
     if chains < 1:
         raise ValueError('chains must be positive integer value, '
                          'found {}'.format(chains))
     self.csv_files = []
     """per-chain sample csv files."""
     if args.output_basename is None:
         csv_basename = 'stan-{}-draws'.format(args.model_name)
         for i in range(chains):
             # delete=False keeps the file on disk once the handle is
             # closed; only its path is needed, so close the handle right
             # away instead of leaking one open descriptor per chain.
             with tempfile.NamedTemporaryFile(
                     mode='w+',
                     prefix='{}-{}-'.format(csv_basename, i + 1),
                     suffix='.csv',
                     dir=TMPDIR,
                     delete=False,
             ) as fd:
                 self.csv_files.append(fd.name)
     else:
         for i in range(chains):
             self.csv_files.append('{}-{}.csv'.format(
                 args.output_basename, i + 1))
     self.console_files = [
         os.path.splitext(csv_file)[0] + '.txt'
         for csv_file in self.csv_files
     ]
     """per-chain sample console output files."""
     self.cmds = [
         args.compose_command(i, csv_file)
         for i, csv_file in enumerate(self.csv_files)
     ]
     """per-chain sampler command."""
     self._retcodes = [-1 for _ in range(chains)]
     # result caches; all start unset
     self._draws = None
     self._column_names = None
     self._num_params = None  # metric dim(s)
     self._metric_type = None
     self._metric = None
     self._stepsize = None
     self._sample = None
     self._first_draw = None
예제 #8
0
 def test_args_sig_figs(self):
     """Check how ``sig_figs`` is handled for the installed CmdStan.

     Before CmdStan 2.25, passing ``sig_figs`` logs a warning; from 2.25 on
     it is rendered on the command line and out-of-range values (-1, 20)
     raise ValueError.
     """
     sampler_args = SamplerArgs()
     path = cmdstan_path()  # sets os.environ['CMDSTAN']
     if cmdstan_version_before(2, 25):
         with LogCapture() as log:
             CmdStanArgs(
                 model_name='bernoulli',
                 model_exe='bernoulli.exe',
                 chain_ids=[1, 2, 3, 4],
                 sig_figs=12,
                 method_args=sampler_args,
             )
         expect = (
             'Argument "sig_figs" invalid for CmdStan versions < 2.25, '
             'using version {} in directory {}').format(
                 os.path.basename(path),
                 os.path.dirname(path),
             )
         log.check_present(('cmdstanpy', 'WARNING', expect))
     else:
         cmdstan_args = CmdStanArgs(
             model_name='bernoulli',
             model_exe='bernoulli.exe',
             chain_ids=[1, 2, 3, 4],
             sig_figs=12,
             method_args=sampler_args,
         )
         cmd = cmdstan_args.compose_command(idx=0,
                                            csv_file='bern-output-1.csv')
         self.assertIn('sig_figs=', ' '.join(cmd))
         # values outside the accepted range are rejected
         for bad_sig_figs in (-1, 20):
             with self.assertRaises(ValueError):
                 CmdStanArgs(
                     model_name='bernoulli',
                     model_exe='bernoulli.exe',
                     chain_ids=[1, 2, 3, 4],
                     sig_figs=bad_sig_figs,
                     method_args=sampler_args,
                 )
예제 #9
0
    def __init__(self,
                 args: CmdStanArgs,
                 chains: int = 4,
                 logger: logging.Logger = None) -> None:
        """Initialize object.

        For each of ``chains`` chains this records the CSV output path, the
        console ``.txt`` path derived from it, an optional diagnostic CSV
        path (only when ``args.save_diagnostics`` is set) and the command
        line returned by ``args.compose_command``.

        :param args: CmdStan arguments used to compose each chain's command.
        :param chains: number of chains; must be at least 1.
        :param logger: logger to use; falls back to ``get_logger()``.
        :raises ValueError: if ``chains`` is less than 1.
        """
        self._args = args
        self._chains = chains
        self._logger = logger or get_logger()
        if chains < 1:
            raise ValueError('chains must be positive integer value, '
                             'found {}'.format(chains))

        self._retcodes = [-1 for _ in range(chains)]

        # output and console messages are written to a text file:
        # ``<model_name>-<YYYYMMDDHHMM>-<chain_id>.txt``
        now = datetime.now()
        now_str = now.strftime('%Y%m%d%H%M')
        file_basename = '-'.join([args.model_name, now_str])
        if args.output_dir is not None:
            output_dir = args.output_dir
        else:
            output_dir = TMPDIR

        self._csv_files = []
        # one slot per chain; a slot stays None unless diagnostics are saved
        self._diagnostic_files = [None for _ in range(chains)]
        self._console_files = []
        self._cmds = []
        for i in range(chains):
            if args.output_dir is None:
                csv_file = create_named_text_file(
                    dir=output_dir,
                    prefix='{}-{}-'.format(file_basename, i + 1),
                    suffix='.csv',
                )
            else:
                csv_file = os.path.join(
                    output_dir, '{}-{}.{}'.format(file_basename, i + 1, 'csv'))
            self._csv_files.append(csv_file)
            txt_file = ''.join([os.path.splitext(csv_file)[0], '.txt'])
            self._console_files.append(txt_file)
            if args.save_diagnostics:
                if args.output_dir is None:
                    diag_file = create_named_text_file(
                        dir=TMPDIR,
                        prefix='{}-diagnostic-{}-'.format(
                            file_basename, i + 1),
                        suffix='.csv',
                    )
                else:
                    diag_file = os.path.join(
                        output_dir,
                        '{}-diagnostic-{}.{}'.format(file_basename, i + 1,
                                                     'csv'),
                    )
                # store into this chain's pre-allocated slot; appending
                # would leave slot i as None and drop the diagnostic file
                # from the composed command
                self._diagnostic_files[i] = diag_file
                self._cmds.append(
                    args.compose_command(i, self._csv_files[i],
                                         self._diagnostic_files[i]))
            else:
                self._cmds.append(args.compose_command(i, self._csv_files[i]))
예제 #10
0
    def __init__(
        self,
        args: CmdStanArgs,
        chains: int = 4,
        chain_ids: List[int] = None,
        logger: logging.Logger = None,
    ) -> None:
        """Initialize object.

        Validates ``chains`` and ``chain_ids``, then for every chain records
        the CSV output path, the ``-stdout.txt`` / ``-stderr.txt`` paths
        derived from it, an optional diagnostic CSV path (only when
        ``args.save_diagnostics`` is set) and the command line returned by
        ``args.compose_command``.

        :param args: CmdStan arguments used to compose each chain's command.
        :param chains: number of chains; must be at least 1.
        :param chain_ids: per-chain ids; defaults to ``1 .. chains``.
        :param logger: logger to use; falls back to ``get_logger()``.
        :raises ValueError: if ``chains`` < 1, or if ``chain_ids`` is given
            with a length different from ``chains``.
        """
        self._args = args
        self._chains = chains
        self._logger = logger or get_logger()
        if chains < 1:
            raise ValueError('chains must be positive integer value, '
                             'found {}'.format(chains))
        if chain_ids is None:
            # default ids are 1-based
            chain_ids = list(range(1, chains + 1))
        elif len(chain_ids) != chains:
            raise ValueError(
                'mismatch between number of chains and chain_ids, '
                'found {} chains, but {} chain_ids'.format(
                    chains, len(chain_ids)))
        self._chain_ids = chain_ids
        self._retcodes = [-1] * chains

        # file names share the prefix <model_name>-<YYYYMMDDHHMM>-<chain_id>;
        # stdout/stderr files append -stdout.txt / -stderr.txt to the CSV stem
        timestamp = datetime.now().strftime('%Y%m%d%H%M')
        file_basename = '-'.join([args.model_name, timestamp])
        use_tmp = args.output_dir is None
        output_dir = _TMPDIR if use_tmp else args.output_dir

        self._csv_files = [None] * chains
        self._diagnostic_files = [None] * chains
        self._stdout_files = [None] * chains
        self._stderr_files = [None] * chains
        self._cmds = []
        for i in range(chains):
            label = str(chain_ids[i])
            if use_tmp:
                csv_file = create_named_text_file(
                    dir=output_dir,
                    prefix='{}-{}-'.format(file_basename, label),
                    suffix='.csv',
                )
            else:
                csv_file = os.path.join(
                    output_dir,
                    '{}-{}.{}'.format(file_basename, label, 'csv'),
                )
            self._csv_files[i] = csv_file
            stem = os.path.splitext(csv_file)[0]
            self._stdout_files[i] = stem + '-stdout.txt'
            self._stderr_files[i] = stem + '-stderr.txt'

            if not args.save_diagnostics:
                self._cmds.append(args.compose_command(i, csv_file))
                continue

            if use_tmp:
                diag_file = create_named_text_file(
                    dir=_TMPDIR,
                    prefix='{}-diagnostic-{}-'.format(file_basename, label),
                    suffix='.csv',
                )
            else:
                diag_file = os.path.join(
                    output_dir,
                    '{}-diagnostic-{}.{}'.format(file_basename, label, 'csv'),
                )
            self._diagnostic_files[i] = diag_file
            self._cmds.append(args.compose_command(i, csv_file, diag_file))