def test_can_pass_vertex_to_vertex(jvm_view: JVMView) -> None:
    """A Vertex can be used directly as a parameter of another Vertex.

    A GaussianVertex is built whose first parameter is itself a
    GaussianVertex; sampling the downstream vertex must give a 0-d value of
    dtype float whose type is ``numpy_types``.
    """
    upstream_mu = Vertex(jvm_view.GaussianVertex, 0., 1.)
    downstream = Vertex(jvm_view.GaussianVertex, upstream_mu, 1.)

    drawn = downstream.sample()

    assert type(drawn) == numpy_types
    assert drawn.shape == ()  # 0-d (scalar) sample
    assert drawn.dtype == float
def test_id_str_of_downstream_vertex_is_higher_than_upstream(
        jvm_view: JVMView) -> None:
    """A vertex built on top of another gets a greater id than its parent.

    Despite the "id_str" in the test name, ``get_id()`` returns tuples, so
    the ordering checked here is tuple (lexicographic) comparison.
    """
    upstream = Vertex(jvm_view.GaussianVertex, 0., 1.)
    downstream = Vertex(jvm_view.GaussianVertex, 0., upstream)

    upstream_id, downstream_id = upstream.get_id(), downstream.get_id()

    for vertex_id in (upstream_id, downstream_id):
        assert type(vertex_id) == tuple
    assert upstream_id < downstream_id
def test_get_vertex_id(jvm_view: JVMView) -> None:
    """get_id() exposes exactly the Java vertex id values, in the same order."""
    gaussian = Vertex(jvm_view.GaussianVertex, 0., 1.)
    java_id = gaussian.unwrap().getId().getValue()
    python_id = gaussian.get_id()
    # Compare the whole sequence: a per-element membership check would still
    # pass if the Python id had its values reordered or carried extra entries.
    assert tuple(java_id) == python_id
def test_cannot_pass_generic_to_vertex(jvm_view: JVMView) -> None:
    """Passing an arbitrary unsupported object to Vertex raises ValueError."""
    import re

    class GenericExampleClass:
        pass

    # pytest.raises(match=...) treats the string as a regular expression.
    # The expected message contains '.' characters (after "argument" and in
    # the dotted class path), so escape it to match literally rather than
    # letting each '.' act as a wildcard.
    expected_message = re.escape(
        "Can't parse generic argument. Was given {}".format(
            GenericExampleClass))
    with pytest.raises(ValueError, match=expected_message):
        Vertex(  # type: ignore # this is expected to fail mypy
            jvm_view.GaussianVertex, "gaussian", GenericExampleClass(),
            GenericExampleClass())
def test_java_collections_to_generator(jvm_view: JVMView) -> None:
    """Vertex._to_generator yields one Python Vertex per Java vertex.

    Each yielded element must be a Vertex whose id is among the ids of the
    Java collection's vertices.
    """
    source = Vertex(jvm_view.GaussianVertex, 0., 1.)
    connected = source.unwrap().getConnectedGraph()

    wrapped = [*Vertex._to_generator(connected)]
    expected_ids = list(map(Vertex._get_python_id, connected))

    assert connected.size() == len(wrapped)
    for candidate in wrapped:
        assert type(candidate) == Vertex
        assert candidate.get_id() in expected_ids
def test_can_pass_array_to_vertex(jvm_view: JVMView) -> None:
    """The list [3, 3] (presumably the shape) is reflected in the sample.

    NOTE(review): another function named test_can_pass_array_to_vertex
    appears later in this source. If both live in the same module, this
    definition is shadowed and pytest never collects it; consider renaming.
    """
    expected_shape = (3, 3)
    vertex = Vertex(jvm_view.GaussianVertex, "gaussian", [3, 3], Const(0.),
                    Const(1.))
    drawn = vertex.sample()
    assert drawn.shape == expected_shape
def test_construct_vertex_with_java_vertex(jvm_view: JVMView) -> None:
    """Wrapping an existing Java vertex in Vertex keeps the Java id."""
    java_gaussian = Gaussian(0., 1.).unwrap()
    wrapped = Vertex(java_gaussian)

    expected_id = tuple(java_gaussian.getId().getValue())
    assert expected_id == wrapped.get_id()
def get_double_model_output_vertex(self, label: str) -> Vertex:
    """Fetch a model output vertex by label and wrap it as a Python Vertex.

    :param label: output name; converted to a VertexLabel and unwrapped
        before being passed to the Java getDoubleModelOutputVertex call.
    :return: the Java result of that call, wrapped in Vertex.
    """
    # The label is built before self.unwrap() is touched, as in the original.
    java_label = VertexLabel(label).unwrap()
    return Vertex(self.unwrap().getDoubleModelOutputVertex(java_label))
def __wrap(vertices: JavaMap) -> Dict[str, Vertex]:
    """Turn a Java map of label -> Java vertex into a Python dict.

    Keys are Java objects exposing getUnqualifiedName() (presumably vertex
    labels) and become plain name strings; each value is wrapped in Vertex.
    """
    wrapped = {}
    for java_key, java_vertex in vertices.items():
        # Key is computed before the value, matching dict-comprehension order.
        name = java_key.getUnqualifiedName()
        wrapped[name] = Vertex(java_vertex)
    return wrapped
def test_can_pass_array_to_vertex(jvm_view: JVMView) -> None:
    """Passing [3, 3] ahead of the Gaussian parameters yields a 3x3 sample."""
    shape_arg = [3, 3]
    vertex = Vertex(jvm_view.GaussianVertex, shape_arg, 0., 1.)
    drawn = vertex.sample()
    assert drawn.shape == (3, 3)
def test_can_pass_pandas_series_to_vertex(jvm_view: JVMView) -> None:
    """pandas Series parameters are accepted; the sample matches their length.

    Annotated like the sibling tests so mypy checks this one as well.
    """
    mu = pd.Series(data=[0.1, 0.4])
    sigma = pd.Series(data=[0.1, 0.4])
    gaussian = Vertex(jvm_view.GaussianVertex, mu, sigma)
    sample = gaussian.sample()
    assert sample.shape == (2, )
def test_can_pass_pandas_dataframe_to_vertex(jvm_view: JVMView) -> None:
    """Single-column DataFrame parameters give a (2, 1) sample."""
    mu_frame = pd.DataFrame(data=[0.1, 0.4])
    sigma_frame = pd.DataFrame(data=[0.1, 0.4])

    drawn = Vertex(jvm_view.GaussianVertex, mu_frame, sigma_frame).sample()

    rows, columns = 2, 1
    assert drawn.shape == (rows, columns)
def test_can_pass_ndarray_to_vertex(jvm_view: JVMView) -> None:
    """1-D numpy array parameters give a sample of the same length."""
    mu_array = np.array([0.1, 0.4])
    sigma_array = np.array([0.4, 0.5])

    drawn = Vertex(jvm_view.GaussianVertex, mu_array, sigma_array).sample()

    assert drawn.shape == (2, )
def test_get_connected_graph(jvm_view: JVMView) -> None:
    """A Gaussian built from two literal parameters has 3 connected vertices.

    Presumably the Gaussian itself plus one vertex per literal parameter;
    the graph is collected into a set so only distinct vertices count.
    """
    gaussian = Vertex(jvm_view.GaussianVertex, 0., 1.)
    distinct_vertices = {*gaussian.get_connected_graph()}
    assert len(distinct_vertices) == 3