def to_kanga(self, keys=None):
    """Export stored chain quantities as a kanga ``ChainArray``.

    :param keys: names of the quantities to export. When falsy (``None`` or
        empty), the keys of ``self.vals`` are used instead. Only 'sample',
        'target_val', 'grad_val' and 'accepted' are exported; any other
        requested name is silently ignored.
    :return: ``ChainArray`` built from a dict that maps each exported name
        to a NumPy array.
    """
    supported = ('sample', 'target_val', 'grad_val', 'accepted')
    selected = set(keys or self.vals.keys()) & set(supported)

    def as_numpy(tensor):
        # Detach from the autograd graph and move to host memory before
        # converting to a NumPy array.
        return tensor.detach().cpu().numpy()

    # One zero-argument exporter per supported name; each is only evaluated
    # when its name has actually been selected.
    exporters = {
        'sample': lambda: as_numpy(self.get_samples()),
        'target_val': lambda: as_numpy(self.get_target_vals()),
        'grad_val': lambda: as_numpy(self.get_grad_vals()),
        'accepted': lambda: np.array(self.vals['accepted']),
    }

    return ChainArray({name: exporters[name]() for name in selected})
# %% Import packages

import kanga.plots as ps

from kanga.chains import ChainArray

from bnn_mcmc_examples.examples.mlp.pima.setting2.constants import diagnostic_iter_thres
from bnn_mcmc_examples.examples.mlp.pima.setting2.metropolis_hastings.constants import sampler_output_pilot_path
from bnn_mcmc_examples.examples.mlp.pima.setting2.model import model

# %% Load chain array

# Only the sampled parameters and the acceptance indicators are read from the
# pilot run output.
chain_array = ChainArray.from_file(
    keys=['sample', 'accepted'],
    path=sampler_output_pilot_path
)

# %% Drop burn-in samples

# Keep everything from iteration `diagnostic_iter_thres` onwards.
post_burn_in = slice(diagnostic_iter_thres, None)

chain_array.vals['sample'] = chain_array.vals['sample'][post_burn_in, :]
chain_array.vals['accepted'] = chain_array.vals['accepted'][post_burn_in]

# %% Plot traces of simulated chain

for i in range(model.num_params()):
    # Titles are labelled from 1, while parameters are fetched by 0-based index.
    trace_title = r'Traceplot of $\theta_{{{}}}$'.format(i + 1)
    ps.trace(
        chain_array.get_param(i),
        title=trace_title,
        xlabel='Iteration',
        ylabel='Parameter value'
    )

# %% Plot running means of simulated chain