def sample(net: BayesNet,
           sample_from: Iterable[Vertex],
           algo: str = 'metropolis',
           proposal_distribution: str = None,
           proposal_distribution_sigma: numpy_types = None,
           proposal_listeners=None,
           draws: int = 500,
           drop: int = 0,
           down_sample_interval: int = 1,
           plot: bool = False,
           ax: Any = None) -> sample_types:
    """Draw posterior samples for the given vertices of a Bayes net.

    :param net: the network to sample from (unwrapped before being handed to
        the sampling algorithm).
    :param sample_from: vertices whose samples are collected and returned.
    :param algo: name of the sampling algorithm; forwarded to
        ``build_sampling_algorithm``.
    :param proposal_distribution: forwarded to ``build_sampling_algorithm``.
    :param proposal_distribution_sigma: forwarded to
        ``build_sampling_algorithm``.
    :param proposal_listeners: forwarded to ``build_sampling_algorithm``.
        ``None`` (the default) means an empty list.
    :param draws: number of samples requested from ``getPosteriorSamples``.
    :param drop: passed to ``.drop(...)`` on the sample set; presumably the
        number of leading samples to discard (confirm against the Java API).
    :param down_sample_interval: passed to ``.downSample(...)``; presumably
        keeps every n-th sample (confirm against the Java API).
    :param plot: if True, draw a traceplot of the collected samples.
    :param ax: axes forwarded to ``traceplot`` when ``plot`` is True.
    :return: dict mapping each vertex's Python label to a list of ndarrays,
        one entry per retained sample.
    """
    # Build a fresh list per call. A literal ``[]`` default is created once at
    # definition time and shared by every call that omits the argument.
    if proposal_listeners is None:
        proposal_listeners = []

    sampling_algorithm: JavaObject = build_sampling_algorithm(
        algo, proposal_distribution, proposal_distribution_sigma, proposal_listeners)

    vertices_unwrapped: JavaList = k.to_java_object_list(sample_from)

    network_samples: JavaObject = sampling_algorithm.getPosteriorSamples(
        net.unwrap(), vertices_unwrapped, draws).drop(drop).downSample(down_sample_interval)

    # Convert each vertex's Java sample list into a Python list of ndarrays,
    # keyed by the vertex's Python-side label.
    vertex_samples = {
        Vertex._get_python_label(vertex_unwrapped): list(
            map(Tensor._to_ndarray, network_samples.get(vertex_unwrapped).asList()))
        for vertex_unwrapped in vertices_unwrapped
    }

    if plot:
        traceplot(vertex_samples, ax=ax)

    return vertex_samples
def _samples_generator(sample_iterator: JavaObject,
                       vertices_unwrapped: JavaList,
                       live_plot: bool,
                       refresh_every: int,
                       ax: Any) -> sample_generator_types:
    """Yield posterior samples one at a time from a Java sample iterator.

    Each yielded value is a dict mapping every vertex's Python label to the
    ndarray for that draw. The generator loops forever, pulling the next draw
    with ``sample_iterator.next()`` each time it is advanced.

    When ``live_plot`` is true, drawn samples are buffered. Whenever the buffer
    length is a multiple of ``refresh_every``, the buffered draws are plotted
    with ``traceplot`` starting at x offset ``x0``. The offset then advances by
    ``refresh_every`` and the buffer is reset. If ``ax`` is None, the value
    returned by the first ``traceplot`` call is kept and passed as ``ax`` on
    later refreshes.
    """
    pending = []
    x_offset = 0
    while True:
        draw = sample_iterator.next()

        current = {}
        for java_vertex in vertices_unwrapped:
            # Resolve the label before converting the value, matching the
            # key-then-value evaluation order of a dict comprehension.
            label = Vertex._get_python_label(java_vertex)
            current[label] = Tensor._to_ndarray(draw.get(java_vertex))

        if live_plot:
            pending.append(current)
            if len(pending) % refresh_every == 0:
                # Regroup the buffered draws as one list of values per label.
                joined = {
                    name: [buffered[name] for buffered in pending]
                    for name in current
                }
                if ax is not None:
                    traceplot(joined, ax=ax, x0=x_offset)
                else:
                    ax = traceplot(joined, x0=x_offset)
                x_offset += refresh_every
                pending = []

        yield current