Example #1
0
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset.

        Samples for the sites listed in ``self.latent_vars`` are collected
        from ``self.posterior`` into a dict keyed by site name, which is then
        passed to ``dict_to_dataset`` along with ``self.coords`` and
        ``self.dims``.

        Returns
        -------
        The value returned by ``dict_to_dataset``.
        """
        # Imported lazily so that pyro does not become a hard requirement.
        from pyro.infer import EmpiricalMarginal

        try:
            # pyro>=0.3 syntax: a single marginal covers all requested sites.
            marginal = self.posterior.marginal(sites=self.latent_vars)
            data = {}
            for name, empirical in marginal.empirical.items():
                single_chain = self.posterior.num_chains == 1
                draws = empirical.enumerate_support().squeeze()
                if single_chain:
                    # Only the single-chain case goes through
                    # utils.expand_dims (presumably to add the missing chain
                    # axis -- confirm against utils).
                    draws = utils.expand_dims(draws)
                data[name] = draws
        except AttributeError:
            # pyro<0.3 syntax: any AttributeError above lands here, and the
            # samples are rebuilt from scratch one site at a time.
            data = {}
            for var_name in self.latent_vars:
                # pylint: disable=no-member
                marginal = EmpiricalMarginal(self.posterior, sites=var_name)
                draws = marginal.get_samples_and_weights()[0]
                data[var_name] = utils.expand_dims(draws.numpy().squeeze())
        return dict_to_dataset(
            data, library=self.pyro, coords=self.coords, dims=self.dims
        )
Example #2
0
 def observed_data_to_xarray(self):
     """Convert observed data to xarray.

     For each site in ``self.observed_vars`` the samples are taken from an
     ``EmpiricalMarginal`` over ``self.posterior``, squeezed, and given a new
     leading axis before being passed to ``dict_to_dataset``.

     Returns
     -------
     The value returned by ``dict_to_dataset``.
     """
     # Imported lazily, inside the method rather than at module level.
     from pyro.infer import EmpiricalMarginal

     data = {
         site: np.expand_dims(
             # Element 0 of get_samples_and_weights() is the samples.
             EmpiricalMarginal(self.posterior, sites=site)
             .get_samples_and_weights()[0]
             .numpy()
             .squeeze(),
             0,
         )
         for site in self.observed_vars
     }
     return dict_to_dataset(
         data, library=self.pyro, coords=self.coords, dims=self.dims
     )
Example #3
0
 def posterior_to_xarray(self):
     """Convert the posterior to an xarray dataset.

     For each site in ``self.latent_vars`` the samples are taken from an
     ``EmpiricalMarginal`` over ``self.posterior``, squeezed, and given a new
     leading axis before being passed to ``dict_to_dataset``.

     Returns
     -------
     The value returned by ``dict_to_dataset``.
     """
     # Do not make pyro a requirement: import only when the method runs.
     from pyro.infer import EmpiricalMarginal

     def _draws_for(site):
         # Samples (element 0 of the samples/weights pair) as a squeezed
         # numpy array with a new axis inserted at position 0.
         marginal = EmpiricalMarginal(self.posterior, sites=site)
         draws = marginal.get_samples_and_weights()[0]
         return np.expand_dims(draws.numpy().squeeze(), 0)

     data = {site: _draws_for(site) for site in self.latent_vars}
     return dict_to_dataset(
         data, library=self.pyro, coords=self.coords, dims=self.dims
     )
Example #4
0
    def observed_data_to_xarray(self):
        """Convert observed data to xarray.

        Samples for the sites listed in ``self.observed_vars`` are collected
        from ``self.posterior`` into a dict keyed by site name, each with a
        new axis inserted at position 0, and passed to ``dict_to_dataset``
        along with ``self.coords`` and ``self.dims``.

        Returns
        -------
        The value returned by ``dict_to_dataset``.
        """
        # Imported lazily, inside the method rather than at module level.
        from pyro.infer import EmpiricalMarginal

        try:
            # pyro>=0.3 syntax: a single marginal covers all requested sites.
            marginal = self.posterior.marginal(sites=self.observed_vars)
            data = {}
            for name, empirical in marginal.empirical.items():
                data[name] = np.expand_dims(empirical.enumerate_support(), 0)
        except AttributeError:
            # pyro<0.3 syntax: any AttributeError above lands here, and the
            # samples are rebuilt from scratch one site at a time.
            data = {
                var_name: np.expand_dims(
                    EmpiricalMarginal(self.posterior, sites=var_name)
                    .get_samples_and_weights()[0]
                    .numpy()
                    .squeeze(),
                    0,
                )
                for var_name in self.observed_vars
            }
        return dict_to_dataset(
            data, library=self.pyro, coords=self.coords, dims=self.dims
        )