Exemple #1
0
    def validate(self, chains: int) -> None:
        """
        Check arguments correctness and consistency.

        * adaptation and warmup args are consistent
        * if file(s) for metric are supplied, check contents.
        * length of per-chain lists equals specified # of chains

        Side effects: the metric aliases 'diag' and 'dense' are normalized to
        'diag_e' and 'dense_e'.  When the metric is given as a file, or as a
        per-chain list of files, ``self.metric_file`` is set to that value
        and ``self.metric`` is replaced by the metric type inferred from the
        dimensions returned by ``read_metric``.

        :param chains: number of chains; per-chain lists (``step_size``,
            metric files) must contain exactly this many entries.
        :raises ValueError: if an argument is invalid or arguments conflict.
        """

        def check_int(name, value, minimum):
            # Test the type before comparing so that a non-numeric value
            # (e.g. a string) raises ValueError, not a TypeError from ``<``.
            if not isinstance(value, Integral) or value < minimum:
                raise ValueError('{} must be a {} integer, found {}'.format(
                    name, 'non-negative' if minimum == 0 else 'positive',
                    value))

        if not isinstance(chains, Integral) or chains < 1:
            raise ValueError(
                'sampler expects number of chains to be greater than 0')

        # Explicitly supplied adaptation arguments, in a fixed order so that
        # error messages are deterministic.
        adapt_given = [
            (name, value) for name, value in (
                ('adapt_delta', self.adapt_delta),
                ('adapt_init_phase', self.adapt_init_phase),
                ('adapt_metric_window', self.adapt_metric_window),
                ('adapt_step_size', self.adapt_step_size),
            ) if value is not None
        ]
        if adapt_given and self.adapt_engaged is False:
            raise ValueError(
                'conflicting arguments: adapt_engaged: False' + ''.join(
                    ', {}: {}'.format(name, value)
                    for name, value in adapt_given))

        if self.iter_warmup is not None:
            check_int('iter_warmup', self.iter_warmup, 0)
            if self.iter_warmup > 0 and not self.adapt_engaged:
                raise ValueError(
                    'adapt_engaged is False, cannot specify warmup iterations')
        if self.iter_sampling is not None:
            check_int('iter_sampling', self.iter_sampling, 0)
        if self.thin is not None:
            check_int('thin', self.thin, 1)
        if self.max_treedepth is not None:
            check_int('max_treedepth', self.max_treedepth, 1)

        if self.step_size is not None:
            if isinstance(self.step_size, Real):
                if self.step_size <= 0:
                    raise ValueError('step_size must be > 0, found {}'.format(
                        self.step_size))
            else:
                if len(self.step_size) != chains:
                    raise ValueError(
                        'number of step_sizes must match number of chains,'
                        ' found {} step_sizes for {} chains'.format(
                            len(self.step_size), chains))
                for step_size in self.step_size:
                    # Zero is rejected too, matching the scalar check above.
                    if step_size <= 0:
                        raise ValueError(
                            'step_size must be > 0, found {}'.format(
                                step_size))

        if self.metric is not None:
            dims = []
            if isinstance(self.metric, str):
                if self.metric in ['diag', 'diag_e']:
                    self.metric = 'diag_e'
                elif self.metric in ['dense', 'dense_e']:
                    self.metric = 'dense_e'
                else:
                    if not os.path.exists(self.metric):
                        raise ValueError('no such file {}'.format(self.metric))
                    dims = read_metric(self.metric)
            elif isinstance(self.metric, (list, tuple)):
                if len(self.metric) != chains:
                    raise ValueError(
                        'number of metric files must match number of chains,'
                        ' found {} metric files for {} chains'.format(
                            len(self.metric), chains))
                if len(set(self.metric)) != len(self.metric):
                    raise ValueError(
                        'each chain must have its own metric file,'
                        ' found duplicates in metric files list.')
                for i, metric in enumerate(self.metric):
                    if not os.path.exists(metric):
                        raise ValueError('no such file {}'.format(metric))
                    if i == 0:
                        dims = read_metric(metric)
                    elif list(read_metric(metric)) != list(dims):
                        # Every per-chain file must have the same dimensions
                        # as the first one.
                        raise ValueError('metrics files {}, {},'
                                         ' inconsistent metrics'.format(
                                             self.metric[0], metric))
            if any(dims):
                # Only a vector (-> diag_e) or a square matrix (-> dense_e)
                # is accepted.
                if len(dims) > 2 or (len(dims) == 2 and dims[0] != dims[1]):
                    raise ValueError('bad metric specifiation')
                self.metric_file = self.metric
                if len(dims) == 1:
                    self.metric = 'diag_e'
                elif len(dims) == 2:
                    self.metric = 'dense_e'

        if self.adapt_delta is not None:
            if not 0 < self.adapt_delta < 1:
                raise ValueError('adapt_delta must be between 0 and 1,'
                                 ' found {}'.format(self.adapt_delta))
        if self.adapt_init_phase is not None:
            check_int('adapt_init_phase', self.adapt_init_phase, 0)
        if self.adapt_metric_window is not None:
            check_int('adapt_metric_window', self.adapt_metric_window, 0)
        if self.adapt_step_size is not None:
            check_int('adapt_step_size', self.adapt_step_size, 0)

        if self.fixed_param and (
                (self.iter_warmup is not None and self.iter_warmup > 0)
                or self.save_warmup or self.max_treedepth is not None
                or self.metric is not None or self.step_size is not None
                or adapt_given):
            raise ValueError('when fixed_param=True, cannot specify warmup'
                             ' or adaptation parameters.')
Exemple #2
0
    def validate(self, chains: int) -> None:
        """
        Check arguments correctness and consistency.

        * adaptation and warmup args are consistent
        * if file(s) for metric are supplied, check contents.
        * length of per-chain lists equals specified # of chains

        Side effects: the metric aliases 'diag' and 'dense' are normalized to
        'diag_e' and 'dense_e'.  When the metric is given as a file, or as a
        per-chain list of files, ``self.metric_file`` is set to that value
        and ``self.metric`` is replaced by the metric type inferred from the
        dimensions returned by ``read_metric``.

        :param chains: number of chains; per-chain lists (``step_size``,
            metric files) must contain exactly this many entries.
        :raises ValueError: if an argument is invalid or arguments conflict.
        """
        if not isinstance(chains, Integral) or chains < 1:
            raise ValueError(
                'sampler expects number of chains to be greater than 0'
            )

        if self.warmup_iters is not None:
            if self.warmup_iters < 0:
                raise ValueError(
                    'warmup_iters must be a non-negative integer, '
                    'found {}'.format(self.warmup_iters)
                )
            # Adaptation requires at least one warmup iteration.
            if self.adapt_engaged and self.warmup_iters == 0:
                raise ValueError(
                    'adaptation requested but 0 warmup iterations specified, '
                    'must run warmup iterations'
                )

        if self.sampling_iters is not None:
            if self.sampling_iters < 0:
                raise ValueError(
                    'sampling_iters must be a non-negative integer, '
                    'found {}'.format(self.sampling_iters)
                )

        if self.thin is not None:
            if self.thin < 1:
                raise ValueError(
                    'thin must be at least 1, found {}'.format(self.thin)
                )

        if self.max_treedepth is not None:
            if self.max_treedepth < 1:
                raise ValueError(
                    'max_treedepth must be at least 1, found {}'.format(
                        self.max_treedepth
                    )
                )

        if self.step_size is not None:
            if isinstance(self.step_size, Real):
                # The error message promises "> 0", so zero is rejected.
                if self.step_size <= 0:
                    raise ValueError(
                        'step_size must be > 0, found {}'.format(self.step_size)
                    )
            else:
                if len(self.step_size) != chains:
                    raise ValueError(
                        'number of step_sizes must match number of chains,'
                        ' found {} step_sizes for {} chains'.format(
                            len(self.step_size), chains
                        )
                    )
                for step_size in self.step_size:
                    if step_size <= 0:
                        raise ValueError(
                            'step_size must be > 0, found {}'.format(step_size)
                        )

        if self.metric is not None:
            dims = None
            if isinstance(self.metric, str):
                if self.metric in ['diag', 'diag_e']:
                    self.metric = 'diag_e'
                elif self.metric in ['dense', 'dense_e']:
                    self.metric = 'dense_e'
                else:
                    if not os.path.exists(self.metric):
                        raise ValueError('no such file {}'.format(self.metric))
                    dims = read_metric(self.metric)
            elif isinstance(self.metric, (list, tuple)):
                if len(self.metric) != chains:
                    raise ValueError(
                        'number of metric files must match number of chains,'
                        ' found {} metric files for {} chains'.format(
                            len(self.metric), chains
                        )
                    )
                if len(set(self.metric)) != len(self.metric):
                    raise ValueError(
                        'each chain must have its own metric file,'
                        ' found duplicates in metric files list.'
                    )
                for i, metric in enumerate(self.metric):
                    if not os.path.exists(metric):
                        raise ValueError('no such file {}'.format(metric))
                    if i == 0:
                        dims = read_metric(metric)
                    elif list(read_metric(metric)) != list(dims):
                        # Every per-chain file must match the first one.
                        raise ValueError(
                            'metrics files {}, {},'
                            ' inconsistent metrics'.format(
                                self.metric[0], metric
                            )
                        )
            if dims is not None:
                # Only a vector (-> diag_e) or a square matrix (-> dense_e)
                # is accepted.
                if len(dims) > 2 or (len(dims) == 2 and dims[0] != dims[1]):
                    raise ValueError('bad metric specifiation')
                self.metric_file = self.metric
                if len(dims) == 1:
                    self.metric = 'diag_e'
                elif len(dims) == 2:
                    self.metric = 'dense_e'

        if self.adapt_delta is not None:
            if not 0 < self.adapt_delta < 1:
                raise ValueError(
                    'adapt_delta must be between 0 and 1,'
                    ' found {}'.format(self.adapt_delta)
                )
Exemple #3
0
    def validate(self, chains: Optional[int]) -> None:
        """
        Check arguments correctness and consistency.

        * adaptation and warmup args are consistent
        * if file(s) for metric are supplied, check contents.
        * length of per-chain lists equals specified # of chains

        Side effects: when a metric is given, ``self.metric_type`` is set to
        'diag_e', 'dense_e' or 'unit_e'.  Unless the metric is given by type
        name, ``self.metric_file`` is set to the metric file path (or list of
        per-chain paths); metric dicts are first written to new JSON files
        created in ``_TMPDIR`` via ``write_stan_json``.

        :param chains: number of chains; must be a positive int.  Per-chain
            lists (``step_size``, ``metric``) must have exactly this length.
        :raises ValueError: if an argument is invalid or arguments conflict.
        """
        if not isinstance(chains, int) or chains < 1:
            raise ValueError(
                'Sampler expects number of chains to be greater than 0.')
        if not (self.adapt_delta is None and self.adapt_init_phase is None
                and self.adapt_metric_window is None
                and self.adapt_step_size is None):
            if self.adapt_engaged is False:
                msg = 'Conflicting arguments: adapt_engaged: False'
                if self.adapt_delta is not None:
                    msg = '{}, adapt_delta: {}'.format(msg, self.adapt_delta)
                if self.adapt_init_phase is not None:
                    msg = '{}, adapt_init_phase: {}'.format(
                        msg, self.adapt_init_phase)
                if self.adapt_metric_window is not None:
                    msg = '{}, adapt_metric_window: {}'.format(
                        msg, self.adapt_metric_window)
                if self.adapt_step_size is not None:
                    msg = '{}, adapt_step_size: {}'.format(
                        msg, self.adapt_step_size)
                raise ValueError(msg)

        # For integer-valued arguments the isinstance test comes first, so a
        # non-numeric value raises ValueError rather than a TypeError from
        # the comparison.
        if self.iter_warmup is not None:
            if not isinstance(self.iter_warmup, int) or self.iter_warmup < 0:
                raise ValueError(
                    'Value for iter_warmup must be a non-negative integer,'
                    ' found {}.'.format(self.iter_warmup))
            if self.iter_warmup > 0 and not self.adapt_engaged:
                raise ValueError('Argument "adapt_engaged" is False, '
                                 'cannot specify warmup iterations.')
        if self.iter_sampling is not None:
            if not isinstance(self.iter_sampling, int) or (
                    self.iter_sampling < 0):
                raise ValueError(
                    'Argument "iter_sampling" must be a non-negative integer,'
                    ' found {}.'.format(self.iter_sampling))
        if self.thin is not None:
            if not isinstance(self.thin, int) or self.thin < 1:
                raise ValueError('Argument "thin" must be a positive integer,'
                                 ' found {}.'.format(self.thin))
        if self.max_treedepth is not None:
            if not isinstance(self.max_treedepth, int) or (
                    self.max_treedepth < 1):
                raise ValueError(
                    'Argument "max_treedepth" must be a positive integer,'
                    ' found {}.'.format(self.max_treedepth))
        if self.step_size is not None:
            if isinstance(self.step_size, (float, int)):
                if self.step_size <= 0:
                    raise ValueError('Argument "step_size" must be > 0, '
                                     'found {}.'.format(self.step_size))
            else:
                if len(self.step_size) != chains:
                    raise ValueError(
                        'Expecting {} per-chain step_size specifications,'
                        ' found {}.'.format(chains, len(self.step_size)))
                for i, step_size in enumerate(self.step_size):
                    # Zero is rejected too, matching the scalar check.
                    if step_size <= 0:
                        raise ValueError('Argument "step_size" must be > 0, '
                                         'chain {}, found {}.'.format(
                                             i + 1, step_size))
        if self.metric is not None:
            if isinstance(self.metric, str):
                if self.metric in ['diag', 'diag_e']:
                    self.metric_type = 'diag_e'
                elif self.metric in ['dense', 'dense_e']:
                    self.metric_type = 'dense_e'
                elif self.metric in ['unit', 'unit_e']:
                    self.metric_type = 'unit_e'
                else:
                    if not os.path.exists(self.metric):
                        raise ValueError('no such file {}'.format(self.metric))
                    dims = read_metric(self.metric)
                    if len(dims) == 1:
                        self.metric_type = 'diag_e'
                    else:
                        self.metric_type = 'dense_e'
                    self.metric_file = self.metric
            elif isinstance(self.metric, Dict):
                if 'inv_metric' not in self.metric:
                    raise ValueError(
                        'Entry "inv_metric" not found in metric dict.')
                dims = list(np.asarray(self.metric['inv_metric']).shape)
                if len(dims) == 1:
                    self.metric_type = 'diag_e'
                else:
                    self.metric_type = 'dense_e'
                dict_file = create_named_text_file(dir=_TMPDIR,
                                                   prefix="metric",
                                                   suffix=".json")
                write_stan_json(dict_file, self.metric)
                self.metric_file = dict_file
            elif isinstance(self.metric, (list, tuple)):
                if len(self.metric) != chains:
                    raise ValueError(
                        'Number of metric files must match number of chains,'
                        ' found {} metric files for {} chains.'.format(
                            len(self.metric), chains))
                if all(isinstance(elem, dict) for elem in self.metric):
                    metric_files: List[str] = []
                    for i, metric in enumerate(self.metric):
                        assert isinstance(metric,
                                          dict)  # make the typechecker happy
                        metric_dict: Dict[str, Any] = metric
                        if 'inv_metric' not in metric_dict:
                            raise ValueError(
                                'Entry "inv_metric" not found in metric dict '
                                'for chain {}.'.format(i + 1))
                        chain_dims = list(
                            np.asarray(metric_dict['inv_metric']).shape)
                        if i == 0:
                            dims = chain_dims
                        elif chain_dims != dims:
                            # Report this chain's dims against chain 1's.
                            raise ValueError(
                                'Found inconsistent "inv_metric" entry '
                                'for chain {}: entry has dims '
                                '{}, expected {}.'.format(
                                    i + 1, chain_dims, dims))
                        dict_file = create_named_text_file(dir=_TMPDIR,
                                                           prefix="metric",
                                                           suffix=".json")
                        write_stan_json(dict_file, metric_dict)
                        metric_files.append(dict_file)
                    if len(dims) == 1:
                        self.metric_type = 'diag_e'
                    else:
                        self.metric_type = 'dense_e'
                    self.metric_file = metric_files
                elif all(isinstance(elem, str) for elem in self.metric):
                    metric_files = []
                    for i, metric in enumerate(self.metric):
                        assert isinstance(metric, str)  # typecheck
                        if not os.path.exists(metric):
                            raise ValueError('no such file {}'.format(metric))
                        chain_dims = read_metric(metric)
                        if i == 0:
                            dims = chain_dims
                        elif chain_dims != dims:
                            raise ValueError(
                                'Metrics files {}, {},'
                                ' inconsistent metrics'.format(
                                    self.metric[0], metric))
                        metric_files.append(metric)
                    if len(dims) == 1:
                        self.metric_type = 'diag_e'
                    else:
                        self.metric_type = 'dense_e'
                    self.metric_file = metric_files
                else:
                    raise ValueError(
                        'Argument "metric" must be a list of pathnames or '
                        'Python dicts, found list of {}.'.format(
                            type(self.metric[0])))
            else:
                raise ValueError(
                    'Invalid metric specified, not a recognized metric type, '
                    'must be either a metric type name, a filepath, dict, '
                    'or list of per-chain filepaths or dicts.  Found '
                    'an object of type {}.'.format(type(self.metric)))

        if self.adapt_delta is not None:
            if not 0 < self.adapt_delta < 1:
                raise ValueError(
                    'Argument "adapt_delta" must be between 0 and 1,'
                    ' found {}'.format(self.adapt_delta))
        if self.adapt_init_phase is not None:
            if not isinstance(self.adapt_init_phase, int) or (
                    self.adapt_init_phase < 0):
                raise ValueError(
                    'Argument "adapt_init_phase" must be a non-negative '
                    'integer, found {}'.format(self.adapt_init_phase))
        if self.adapt_metric_window is not None:
            if not isinstance(self.adapt_metric_window, int) or (
                    self.adapt_metric_window < 0):
                raise ValueError(
                    'Argument "adapt_metric_window" must be a non-negative '
                    'integer, found {}'.format(self.adapt_metric_window))
        if self.adapt_step_size is not None:
            if not isinstance(self.adapt_step_size, int) or (
                    self.adapt_step_size < 0):
                raise ValueError(
                    'Argument "adapt_step_size" must be a non-negative integer,'
                    ' found {}'.format(self.adapt_step_size))

        if self.fixed_param and (
                self.max_treedepth is not None or self.metric is not None
                or self.step_size is not None or
                not (self.adapt_delta is None and self.adapt_init_phase is None
                     and self.adapt_metric_window is None
                     and self.adapt_step_size is None)):
            raise ValueError(
                'When fixed_param=True, cannot specify adaptation parameters.')
Exemple #4
0
 def test_metric_rdump_bad_2(self):
     """read_metric rejects metric_bad_2.data.R with an "inv_metric" error."""
     metric_file = os.path.join(datafiles_path, 'metric_bad_2.data.R')
     with self.assertRaisesRegex(Exception,
                                 'bad or missing entry "inv_metric"'):
         # The call is expected to raise, so its result is not kept.
         read_metric(metric_file)
Exemple #5
0
 def test_metric_missing(self):
     """read_metric on a nonexistent file raises 'No such file or directory'."""
     metric_file = os.path.join(datafiles_path, 'no_such_file.json')
     with self.assertRaisesRegex(Exception, 'No such file or directory'):
         # The call is expected to raise, so its result is not kept.
         read_metric(metric_file)
Exemple #6
0
 def test_metric_rdump_vec(self):
     """Diagonal metric in R dump format: one dimension, of length 3."""
     shape = read_metric(os.path.join(datafiles_path, 'metric_diag.data.R'))
     self.assertEqual(1, len(shape))
     self.assertEqual(3, shape[0])
Exemple #7
0
 def test_metric_rdump_matrix(self):
     """Dense metric in R dump format: two dimensions of equal length."""
     rdump_path = os.path.join(datafiles_path, 'metric_dense.data.R')
     shape = read_metric(rdump_path)
     self.assertEqual(2, len(shape))
     self.assertEqual(shape[0], shape[1])
Exemple #8
0
 def test_metric_missing(self):
     """read_metric on a nonexistent file raises 'No such file or directory'."""
     missing_path = os.path.join(DATAFILES_PATH, 'no_such_file.json')
     expected_error = 'No such file or directory'
     with self.assertRaisesRegex(Exception, expected_error):
         read_metric(missing_path)
Exemple #9
0
 def test_metric_rdump_bad_2(self):
     """read_metric rejects metric_bad_2.data.R with an "inv_metric" error."""
     bad_path = os.path.join(DATAFILES_PATH, 'metric_bad_2.data.R')
     expected_error = 'bad or missing entry "inv_metric"'
     with self.assertRaisesRegex(Exception, expected_error):
         read_metric(bad_path)
Exemple #10
0
 def test_metric_json_matrix(self):
     """Dense metric in JSON format: two dimensions of equal length."""
     json_path = os.path.join(DATAFILES_PATH, 'metric_dense.data.json')
     shape = read_metric(json_path)
     self.assertEqual(2, len(shape))
     self.assertEqual(shape[0], shape[1])
Exemple #11
0
 def test_metric_json_vec(self):
     """Diagonal metric in JSON format: one dimension, of length 3."""
     shape = read_metric(os.path.join(DATAFILES_PATH, 'metric_diag.data.json'))
     self.assertEqual(1, len(shape))
     self.assertEqual(3, shape[0])
Exemple #12
0
    def validate(self) -> None:
        """
        Check arguments correctness and consistency.

        * input files must exist
        * output files must be in a writeable directory
        * adaptation and warmup args are consistent
        * if file(s) for metric are supplied, check contents.
        * if no seed specified, set random seed.
        * length of per-chain lists equals specified # of chains
        """
        if self.model is None:
            raise ValueError('no stan model specified')
        if self.model.exe_file is None:
            raise ValueError('stan model must be compiled first,' +
                             ' run command compile_model("{}")'.format(
                                 self.model.stan_file))
        if not os.path.exists(self.model.exe_file):
            raise ValueError('cannot access model executable "{}"'.format(
                self.model.exe_file))

        if self.chain_ids is not None:
            for i in range(len(self.chain_ids)):
                if self.chain_ids[i] < 1:
                    raise ValueError('invalid chain_id {}'.format(
                        self.chain_ids[i]))

        if self.output_file is not None:
            if not os.path.exists(os.path.dirname(self.output_file)):
                raise ValueError('invalid path for output files: {}'.format(
                    self.output_file))
            try:
                with open(self.output_file, 'w+') as fd:
                    pass
                os.remove(self.output_file)  # cleanup
            except Exception:
                raise ValueError('invalid path for output files: {}'.format(
                    self.output_file))
            if self.output_file.endswith('.csv'):
                self.output_file = self.output_file[:-4]

        if self.seed is None:
            rng = np.random.RandomState()
            self.seed = rng.randint(1, 99999 + 1)
        else:
            if not isinstance(self.seed, (int, list)):
                raise ValueError(
                    'seed must be an integer between 0 and 2**32-1,'
                    ' found {}'.format(self.seed))
            elif isinstance(self.seed, int):
                if self.seed < 0 or self.seed > 2**32 - 1:
                    raise ValueError(
                        'seed must be an integer between 0 and 2**32-1,'
                        ' found {}'.format(self.seed))
            else:
                if len(self.seed) != len(self.chain_ids):
                    raise ValueError(
                        'number of seeds must match number of chains '
                        ' found {} seed for {} chains '.format(
                            len(self.seed), len(self.chain_ids)))
                for i in range(len(self.seed)):
                    if self.seed[i] < 0 or self.seed[i] > 2**32 - 1:
                        raise ValueError('seed must be an integer value'
                                         ' between 0 and 2**32-1,'
                                         ' found {}'.format(self.seed[i]))

        if self.data is not None:
            if not os.path.exists(self.data):
                raise ValueError('no such file {}'.format(self.data))

        if self.inits is not None:
            if isinstance(self.inits, (int, float)):
                if self.inits < 0:
                    raise ValueError('inits must be > 0, found {}'.format(
                        self.inits))
            elif isinstance(self.inits, str):
                if not os.path.exists(self.inits):
                    raise ValueError('no such file {}'.format(self.inits))
            elif isinstance(self.inits, List):
                if len(self.inits) != len(self.chain_ids):
                    raise ValueError(
                        'number of inits files must match number of chains '
                        ' found {} inits files for {} chains '.format(
                            len(self.inits), len(self.chain_ids)))
                names_set = set(self.inits)
                if len(names_set) != len(self.inits):
                    raise ValueError('each chain must have its own init file,'
                                     ' found duplicates in inits files list.')
                for i in range(len(self.inits)):
                    if not os.path.exists(self.inits[i]):
                        raise ValueError('no such file {}'.format(
                            self.inits[i]))

        if self.warmup_iters is not None:
            if self.warmup_iters < 0:
                raise ValueError(
                    'warmup_iters must be a non-negative integer'.format(
                        self.warmup_iters))
            if self.adapt_engaged and self.warmup_iters == 0:
                raise ValueError(
                    'adaptation requested but 0 warmup iterations specified, '
                    'must run warmup iterations')

        if self.sampling_iters is not None:
            if self.sampling_iters < 0:
                raise ValueError(
                    'sampling_iters must be a non-negative integer'.format(
                        self.sampling_iters))

        if self.warmup_schedule is not None:
            if self.warmup_iters is not None and self.warmup_iters < 1:
                raise ValueError(
                    'Config error: '
                    'warmup_schedule specified for 0 warmup iterations')
            if len(self.warmup_schedule) != 3 or sum(self.warmup_schedule) > 1:
                raise ValueError(
                    'warmup_schedule should be triple of precentages '
                    ' that sums to 1, e.g. (0.1, 0.8, 0.1), found {}'.format(
                        self.warmup_iters))
            for x in self.warmup_schedule:
                if x < 0 or x > 1:
                    raise ValueError(
                        'warmup_schedule should be triple of precentages that'
                        ' sum to 1, e.g. (0.1, 0.8, 0.1), found {}'.format(
                            self.warmup_schedule))

            num_warmup = 1000
            if self.warmup_iters is not None:
                num_warmup = self.warmup_iters
            self.init_buffer = math.floor(num_warmup * self.warmup_schedule[0])
            self.term_buffer = math.floor(num_warmup * self.warmup_schedule[2])

        if self.thin is not None:
            if self.thin < 1:
                raise ValueError('thin must be at least 1, found {}'.format(
                    self.thin))

        if self.max_treedepth is not None:
            if self.max_treedepth < 1:
                raise ValueError(
                    'max_treedepth must be at least 1, found {}'.format(
                        self.max_treedepth))

        if self.step_size is not None:
            if isinstance(self.step_size, (int, float)):
                if self.step_size < 0:
                    raise ValueError('step_size must be > 0, found {}'.format(
                        self.step_size))
            else:
                if len(self.step_size) != len(self.chain_ids):
                    raise ValueError(
                        'number of step_sizes must match number of chains '
                        ' found {} step_sizes for {} chains '.format(
                            len(self.step_size), len(self.chain_ids)))
                for i in range(len(self.step_size)):
                    if self.step_size[i] < 0:
                        raise ValueError(
                            'step_size must be > 0, found {}'.format(
                                self.step_size[i]))

        if self.metric is not None:
            dims = None
            if isinstance(self.metric, str):
                if self.metric in ['diag', 'diag_e']:
                    self.metric = 'diag_e'
                elif self.metric in ['dense', 'dense_e']:
                    self.metric = 'dense_e'
                else:
                    if not os.path.exists(self.metric):
                        raise ValueError('no such file {}'.format(self.metric))
                    dims = read_metric(self.metric)
            elif isinstance(self.metric, list):
                if len(self.metric) != len(self.chain_ids):
                    raise ValueError(
                        'number of metric files must match number of chains '
                        ' found {} metric files for {} chains '.format(
                            len(self.metric), len(self.chain_ids)))
                names_set = set(self.metric)
                if len(names_set) != len(self.metric):
                    raise ValueError(
                        'each chain must have its own metric file,'
                        ' found duplicates in metric files list.')
                for i in range(len(self.metric)):
                    if not os.path.exists(self.metric[i]):
                        raise ValueError('no such file {}'.format(
                            self.metric[i]))
                    if i == 0:
                        dims = read_metric(self.metric[i])
                    else:
                        dims2 = read_metric(self.metric[i])
                        if len(dims) != len(dims2):
                            raise ValueError('metrics files {}, {},'
                                             ' inconsistent metrics'.format(
                                                 self.metric[0],
                                                 self.metric[i]))
                        for j in range(len(dims)):
                            if dims[j] != dims2[j]:
                                raise ValueError(
                                    'metrics files {}, {},'
                                    ' inconsistent metrics'.format(
                                        self.metric[0], self.metric[i]))
            if dims is not None:
                if len(dims) > 2 or (len(dims) == 2 and dims[0] != dims[1]):
                    raise ValueError('bad metric specifiation')
                self.metric_file = self.metric
                if len(dims) == 1:
                    self.metric = 'diag_e'
                elif len(dims) == 2:
                    self.metric = 'dense_e'

        if self.adapt_delta is not None:
            if self.adapt_delta < 0.0 or self.adapt_delta > 1.0:
                raise ValueError('adapt_delta must be between 0 and 1,'
                                 ' found {}'.format(self.adapt_delta))
        pass