Exemplo n.º 1
0
    def sampling(self, verbose=False, **kwargs):
        """
        Draw samples with PyStan's ``StanModel.sampling`` and unconstrain them.

        Args
        ----------
        verbose : bool
            If True, will print the suppressed print statements and let
            records from the "pystan" logger propagate while sampling.
        kwargs : dict
            keyword arguments passed to PyStan's StanModel.sampling.
            If ``iter`` is given it must be at least 2.

        Returns
        ----------
        np.ndarray
            An array of unconstrained latent variables of the shape
            (num_samples, self.zlen)

        Raises
        ----------
        ValueError
            If ``kwargs['iter']`` is smaller than 2.
        Exception
            If sampling or unconstraining the draws fails; the original
            error is chained as ``__cause__``.
        """
        # Validate explicitly instead of with ``assert``: assertions are
        # removed when Python runs with ``-O``.
        n_iter = kwargs.get('iter', 2)
        if n_iter < 2:
            raise ValueError(
                f"'iter' must be at least 2 for sampling, got {n_iter}.")

        pystan_logger = logging.getLogger("pystan")
        try:
            with utils.suppress_stdout_stderr(verbose):
                pystan_logger.propagate = verbose
                try:
                    self.fit = self.sm.sampling(data=self.data, **kwargs)
                finally:
                    # Reset even when sampling raises, so a failed call does
                    # not leave PyStan log propagation switched on.
                    pystan_logger.propagate = False
            rez = self.unconstrain(self.fit.extract())
        except Exception as e:
            extra_message = ("Error during sampling from the model.")
            raise Exception(extra_message) from e
        return rez
Exemplo n.º 2
0
    def advi(self,
             algorithm='fullrank',
             verbose=False,
             return_constrained=False,
             **kwargs):
        """
            A function to run Stan's variational Bayes method (ADVI).

            Args
            ----------
                algorithm:  "fullrank" or "meanfield"
                verbose:    Boolean
                            If True, prints the optimization messages from ADVI
                return_constrained: Boolean
                            If True, return the draws as produced by
                            self.constrained_array_to_dict; otherwise they
                            are additionally passed through self.unconstrain.
                kwargs:     arguments for vb see
                            https://pystan.readthedocs.io/en/latest/api.html#pystan.StanModel.vb
                            for more details
            Returns
            ----------
                samples:    samples from the final posterior

            Raises
            ----------
                Any exception raised by StanModel.vb is propagated unchanged.
        """
        with utils.suppress_stdout_stderr(verbose):
            rez = self.sm.vb(data=self.data, algorithm=algorithm, **kwargs)

        # rez["sampler_params"] is turned into a (num_params, num_draws)
        # array (assuming one 1-D array of draws per parameter). The last
        # row is dropped (presumably lp__ — confirm against
        # rez["sampler_param_names"]) and the result is transposed to
        # (num_draws, num_params - 1).
        samples = np.array(rez["sampler_params"])[:-1, :].T

        constrained = self.constrained_array_to_dict(samples)
        if return_constrained:
            return constrained
        return self.unconstrain(constrained)
Exemplo n.º 3
0
def get_compiled_model(model_code, model_name=None, verbose=False, **kwargs):
    """ A function to get a PyStan compiled model, using an on-disk cache.

    Compiled models are pickled to
    ``data/cached-models/<utils.get_cache_fname(model_name, model_code)>.pkl``
    (relative to the current working directory). If that file exists it is
    unpickled and returned; otherwise the code is compiled and the result
    is written to that file.

    Args:
        model_code (string):
            Stan code
        model_name (string, optional):
            Passed to StanModel and used to build the cache file name.
            Defaults to None.
        verbose (bool, optional):
            Helps print some PyStan compilation warnings.
            Defaults to False.
        **kwargs:
            Extra keyword arguments forwarded to ``pystan.StanModel``.

    Returns:
        StanModel: Compiled StanModel for the model specified by model_code

    Raises:
        Exception: If the cached model cannot be unpickled, or if compilation
            fails; the underlying error is chained as ``__cause__``.
    """
    cache_dir = 'data/cached-models'
    if not os.path.exists(cache_dir):
        print('creating cached-models/ to save compiled stan models')
        # exist_ok tolerates another process creating the directory between
        # the existence check and this call.
        os.makedirs(cache_dir, exist_ok=True)

    cache_fn = (f'{cache_dir}/'
                f'{utils.get_cache_fname(model_name, model_code)}.pkl')

    if os.path.isfile(cache_fn):
        try:
            # NOTE: pickle.load can execute arbitrary code; this is only safe
            # because the cache file is one this function wrote itself.
            with open(cache_fn, 'rb') as f:
                sm = pickle.load(f)
        except Exception as e:
            extra_message = (f"Error during re-loading the compiled model. "
                             f"Try recompiling the model: change the "
                             f"name of the model or delete "
                             f"the cached file at {cache_fn}.")
            raise Exception(extra_message) from e
        print("Compiled model found.")
        return sm

    print("Cached model not found. Compiling...")
    try:
        with utils.suppress_stdout_stderr(verbose):
            sm = pystan.StanModel(model_code=model_code,
                                  model_name=model_name,
                                  **kwargs)
    except Exception as e:
        extra_message = ("Error during compilation. "
                         f'Could not compile code for {model_code}')
        raise Exception(extra_message) from e

    print("Model compiled. Caching model.")
    # Dump to a temporary file and atomically move it into place, so an
    # interrupted write never leaves a truncated pickle at cache_fn.
    tmp_fn = f'{cache_fn}.{os.getpid()}.tmp'
    try:
        with open(tmp_fn, 'wb') as f:
            pickle.dump(sm, f)
        os.replace(tmp_fn, cache_fn)
    except BaseException:
        try:
            os.remove(tmp_fn)
        except OSError:
            pass
        raise
    return sm
Exemplo n.º 4
0
    def __init__(self, model_code, data, model_name=None, verbose=False):
        """
            Compile (or load from cache) a Stan model and run a short
            sampling check on it.

            Args
            ----------
                model_code (string):
                    Stan code for the model.
                data (dict):
                    Data in the dictionary format.
                model_name (string):
                    Name of the model for easier identification.
                    Together with the code it is used to cache the
                    compiled model.
                verbose (bool):
                    If True, print additional details from Stan
                    compilation.

            Attributes
            ----------
                data(dict):
                     Model data.
                model_name(string):
                     If None, only model_code is used to cache.
                model_code(string):
                     Stan code. Also used to cache.
                sm(StanModel):
                     Compiled StanModel instance.
                fit:
                     Result of a short self.sm.sampling check
                     (iter=100, chains=1, init=0).
                keys(list):
                     Names of the unconstrained parameters.
                zlen(int):
                     Number of latent dimensions in the model.
        """
        # NOTE: -O1 makes the compiled model slow; raise the optimization
        # level at some point.
        compile_flags = ['-O1', '-w', '-Wno-deprecated']

        self.data, self.model_name, self.model_code = data, model_name, model_code

        self.sm = get_compiled_model(model_code=self.model_code,
                                     model_name=self.model_name,
                                     extra_compile_args=compile_flags,
                                     verbose=verbose)

        check_kwargs = dict(data=self.data, iter=100, chains=1, init=0)
        try:
            # Always passes False here, independent of ``verbose``.
            with utils.suppress_stdout_stderr(False):
                self.fit = self.sm.sampling(**check_kwargs)
        except Exception as err:
            cache_name = utils.get_cache_fname(model_name, model_code)
            raise Exception(
                'Error occurred during a sampling check for compiled model. '
                "Try recompiling the model by removing cached model. "
                'Cached model maybe stored at: '
                f'data/cached-models/{cache_name}.pkl') from err

        param_names = self.fit.unconstrained_param_names()
        self.keys = param_names
        self.zlen = len(param_names)