assert all(value in python_id for value in java_id) def test_ids_are_reset() -> None: gaussian = Gaussian(0., 1.) set_deterministic_state() gaussian2 = Gaussian(0., 1.) assert gaussian.get_id() == gaussian2.get_id() @pytest.mark.parametrize("vertex, expected_type", [(Gaussian(0., 1.), np.floating), (UniformInt(0, 10), np.integer), (Bernoulli(0.5), np.bool_)]) @pytest.mark.parametrize( "value, assert_vertex_value_equals", [(np.array([[4]]), assert_vertex_value_equals_ndarray), (np.array([[5.]]), assert_vertex_value_equals_ndarray), (np.array([[True]]), assert_vertex_value_equals_ndarray), (np.array([[1, 2], [3, 4]]), assert_vertex_value_equals_ndarray), (pd.Series(data=[4]), assert_vertex_value_equals_pandas), (pd.Series(data=[5.]), assert_vertex_value_equals_pandas), (pd.Series(data=[True]), assert_vertex_value_equals_pandas), (pd.Series(data=[1, 2, 3]), assert_vertex_value_equals_pandas), (pd.Series(data=[1., 2., 3.]), assert_vertex_value_equals_pandas), (pd.Series(data=[True, False, False]), assert_vertex_value_equals_pandas), (pd.DataFrame(data=[[4]]), assert_vertex_value_equals_pandas), (pd.DataFrame(data=[[5.]]), assert_vertex_value_equals_pandas), (pd.DataFrame(data=[[True]]), assert_vertex_value_equals_pandas),
def create_vertex(item: SequenceItem) -> None:
    """Build a ``Bernoulli(0.5)`` vertex, label it, and add it to ``item``.

    The label comes from ``vertexLabel``, which is defined outside this
    function (not visible here — presumably in an enclosing scope).
    """
    labelled_vertex = Bernoulli(0.5)
    labelled_vertex.set_label(vertexLabel)
    item.add(labelled_vertex)
from typing import Union
import pytest
import numpy as np
import pandas as pd
from keanu.vartypes import tensor_arg_types
from keanu.vertex import Bernoulli, If, Gaussian, Const, Double, Poisson, Integer, Boolean, Exponential, Vertex, Uniform


@pytest.mark.parametrize("predicate", [
    True,
    np.array([True, False]),
    pd.Series([True, False]),
    Bernoulli(0.5),
    Const(np.array([True, False])),
])
@pytest.mark.parametrize("data", [
    1.,
    np.array([1., 2.]),
    pd.Series([1., 2.]),
    Exponential(1.),
    Const(np.array([1., 2.])),
])
def test_you_can_create_a_double_valued_if(
        predicate: Union[tensor_arg_types, Vertex],
        data: Union[tensor_arg_types, Vertex],
) -> None:
    """Check that ``If`` yields a ``Double`` when both branches are double-valued.

    Every ``predicate`` parameter is paired with every ``data`` parameter,
    and the same ``data`` object is passed as both the "then" and the
    "else" branch.
    """
    result = If(predicate, data, data)

    # Exact type match is intentional: a subclass of Double would not pass.
    assert type(result) == Double