def logpdf(self, *args, **kwargs) -> float:
    """Compute the value of the distribution's logpdf.

    The model graph is compiled with ``compiler.compile_to_logpdf`` on every
    call, and the compiled function is then evaluated on the given arguments.

    Args:
        *args: Positional arguments forwarded to the compiled logpdf.
        **kwargs: Keyword arguments forwarded to the compiled logpdf.

    Returns:
        The value returned by the compiled logpdf (annotated as ``float``).
    """
    # Read the compiled function by name rather than positionally unpacking
    # the compilation artifact, so this does not depend on how many fields
    # the artifact has or in which order they are declared.
    artifact = compiler.compile_to_logpdf(self.graph, self.namespace)
    return artifact.compiled_fn(*args, **kwargs)
def logpdf_src(self) -> str:
    """Return the generated source code of the log-probability density function.

    The model's graph and namespace are compiled with
    ``compiler.compile_to_logpdf``, and the source text recorded on the
    resulting artifact is returned.
    """
    return compiler.compile_to_logpdf(self.graph, self.namespace).fn_source
def build_loglikelihood(model, **kwargs):
    """Build a log-likelihood function for ``model`` with some arguments bound.

    Compiles the model's graph to a logpdf and binds ``kwargs`` as keyword
    arguments of the compiled function; the remaining arguments are supplied
    when the returned callable is invoked.

    Args:
        model: Object exposing ``graph`` and ``namespace`` attributes.
        **kwargs: Keyword arguments to bind to the compiled logpdf
            (presumably observed data — confirm against callers).

    Returns:
        A ``functools.partial`` wrapping the compiled logpdf with ``kwargs``
        bound.
    """
    import functools

    artifact = compile_to_logpdf(model.graph, model.namespace)
    # ``jax.partial`` was only a re-export of ``functools.partial`` and has
    # been removed from JAX, so use the standard-library function directly.
    return functools.partial(artifact.compiled_fn, **kwargs)