コード例 #1
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_can_pass_vertex_to_vertex(jvm_view: JVMView) -> None:
    mu = Vertex(jvm_view.GaussianVertex, 0., 1.)
    gaussian = Vertex(jvm_view.GaussianVertex, mu, 1.)
    sample = gaussian.sample()

    assert type(sample) == numpy_types
    assert sample.shape == ()
    assert sample.dtype == float
コード例 #2
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_id_str_of_downstream_vertex_is_higher_than_upstream(
        jvm_view: JVMView) -> None:
    hyper_params = Vertex(jvm_view.GaussianVertex, 0., 1.)
    gaussian = Vertex(jvm_view.GaussianVertex, 0., hyper_params)

    hyper_params_id = hyper_params.get_id()
    gaussian_id = gaussian.get_id()

    assert type(hyper_params_id) == tuple
    assert type(gaussian_id) == tuple

    assert hyper_params_id < gaussian_id
コード例 #3
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_get_vertex_id(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, 0., 1.)

    java_id = gaussian.unwrap().getId().getValue()
    python_id = gaussian.get_id()

    assert all(value in python_id for value in java_id)
コード例 #4
0
ファイル: test_vertex.py プロジェクト: shazbots/keanu
def test_cannot_pass_generic_to_vertex(jvm_view: JVMView) -> None:
    class GenericExampleClass:
        pass

    with pytest.raises(
            ValueError,
            match=r"Can't parse generic argument. Was given {}".format(
                GenericExampleClass)):
        Vertex(  # type: ignore # this is expected to fail mypy
            jvm_view.GaussianVertex, "gaussian", GenericExampleClass(),
            GenericExampleClass())
コード例 #5
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_java_collections_to_generator(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, 0., 1.)

    java_collections = gaussian.unwrap().getConnectedGraph()
    python_list = list(Vertex._to_generator(java_collections))

    java_vertex_ids = [
        Vertex._get_python_id(java_vertex) for java_vertex in java_collections
    ]

    assert java_collections.size() == len(python_list)
    assert all(
        type(element) == Vertex and element.get_id() in java_vertex_ids
        for element in python_list)
コード例 #6
0
ファイル: test_vertex.py プロジェクト: shazbots/keanu
def test_can_pass_array_to_vertex(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, "gaussian", [3, 3], Const(0.),
                      Const(1.))
    sample = gaussian.sample()

    assert sample.shape == (3, 3)
コード例 #7
0
ファイル: test_vertex.py プロジェクト: VirtMarket/keanu
def test_construct_vertex_with_java_vertex(jvm_view: JVMView) -> None:
    java_vertex = Gaussian(0., 1.).unwrap()
    python_vertex = Vertex(java_vertex)

    assert tuple(java_vertex.getId().getValue()) == python_vertex.get_id()
コード例 #8
0
ファイル: lambda_model.py プロジェクト: luke14free/keanu
 def get_double_model_output_vertex(self, label: str) -> Vertex:
     label_unwrapped = VertexLabel(label).unwrap()
     result = self.unwrap().getDoubleModelOutputVertex(label_unwrapped)
     return Vertex(result)
コード例 #9
0
ファイル: lambda_model.py プロジェクト: luke14free/keanu
 def __wrap(vertices: JavaMap) -> Dict[str, Vertex]:
     return {k.getUnqualifiedName(): Vertex(v) for k, v in vertices.items()}
コード例 #10
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_can_pass_array_to_vertex(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, [3, 3], 0., 1.)
    sample = gaussian.sample()

    assert sample.shape == (3, 3)
コード例 #11
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_can_pass_pandas_series_to_vertex(jvm_view):
    gaussian = Vertex(jvm_view.GaussianVertex, pd.Series(data=[0.1, 0.4]),
                      pd.Series(data=[0.1, 0.4]))
    sample = gaussian.sample()

    assert sample.shape == (2, )
コード例 #12
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_can_pass_pandas_dataframe_to_vertex(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, pd.DataFrame(data=[0.1, 0.4]),
                      pd.DataFrame(data=[0.1, 0.4]))
    sample = gaussian.sample()

    assert sample.shape == (2, 1)
コード例 #13
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_can_pass_ndarray_to_vertex(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, np.array([0.1, 0.4]),
                      np.array([0.4, 0.5]))
    sample = gaussian.sample()

    assert sample.shape == (2, )
コード例 #14
0
ファイル: test_vertex.py プロジェクト: luke14free/keanu
def test_get_connected_graph(jvm_view: JVMView) -> None:
    gaussian = Vertex(jvm_view.GaussianVertex, 0., 1.)
    connected_graph = set(gaussian.get_connected_graph())

    assert len(connected_graph) == 3