コード例 #1
0
def test_bad_inputs():
    """MultipleASE must reject invalid constructor arguments and fit inputs.

    Constructor arguments of the wrong type raise TypeError. Inputs that are
    not a collection of at least two equally sized graphs raise ValueError
    from ``fit``.
    """
    np.random.seed(1)
    single_graph = er_np(100, 0.2)
    different_size_graphs = [er_np(100, 0.2), er_np(200, 0.2)]

    # Invalid `scaled`: a string where a bool is expected.
    with pytest.raises(TypeError):
        MultipleASE(scaled="1")

    # Invalid `diag_aug`: a string where a bool is expected.
    wrong_diag_aug = "True"
    with pytest.raises(TypeError):
        MultipleASE(diag_aug=wrong_diag_aug)

    # A single 2-d graph is not a collection of graphs.
    with pytest.raises(ValueError):
        MultipleASE().fit(single_graph)

    # A 3-d tensor that holds only one graph. The reshape stays outside the
    # raises block so only `fit` can satisfy the expectation.
    single_graph_tensor = single_graph.reshape(1, 100, -1)
    with pytest.raises(ValueError):
        MultipleASE().fit(single_graph_tensor)

    # Empty list of graphs.
    with pytest.raises(ValueError):
        MultipleASE().fit([])

    # Graphs with different numbers of vertices (100 vs 200).
    with pytest.raises(ValueError):
        MultipleASE().fit(different_size_graphs)
コード例 #2
0
ファイル: test_plot_matrix.py プロジェクト: zeou1/graspy
def test_matrix_output():
    """Smoke test: matrixplot renders several option combinations without raising."""
    X = er_np(10, 0.5)
    meta = pd.DataFrame(
        {
            "hemisphere": np.random.randint(2, size=10),
            "region": np.random.randint(2, size=10),
            "cell_size": np.random.randint(10, size=10),
        }
    )

    # Every call shares the same row/column metadata; only these options vary.
    option_sets = (
        {},
        {"row_group": "hemisphere"},
        # NOTE(review): col_group_order is given without a col_group here;
        # confirm whether row_group_order was intended.
        {"row_group": "hemisphere", "col_group_order": "size"},
        {"col_group": "hemisphere", "row_item_order": "cell_size"},
    )
    for options in option_sets:
        matrixplot(X, col_meta=meta, row_meta=meta, **options)
コード例 #3
0
 def test_gridplot_outputs(self):
     """Smoke test: gridplot renders with defaults, a transform, and styling options."""
     graphs = [er_np(10, 0.5) for _ in range(2)]
     names = ["Random A", "Random B"]
     gridplot(graphs, names)
     gridplot(graphs, names, transform="zero-boost")
     # The third positional argument is the transform.
     gridplot(graphs, names, "simple-all", title="Test", font_scale=0.9)
コード例 #4
0
def test_omni_matrix_symmetric():
    """The omnibus matrix built from ER graphs must be symmetric.

    Every graph has ``n`` vertices. Only the number of graphs stacked into
    the omnibus matrix varies across cases.
    """
    np.random.seed(3)
    n = 15
    p = 0.4

    n_graphs = [2, 5, 10]
    # Use a separate loop name. Reusing `n` here overwrote the vertex count,
    # so each graph had as many vertices as there were graphs.
    for n_graph in n_graphs:
        graphs = [er_np(n, p) for _ in range(n_graph)]
        output = _get_omni_matrix(graphs)
        assert is_symmetric(output)
コード例 #5
0
    def test_networkplot_inputs(self):
        """Each wrongly-typed networkplot argument must be rejected.

        Every call gets its own ``assertRaises``. When all calls shared one
        outer ``assertRaises``, the first rejected call ended the block and
        the remaining checks never executed.
        """
        X = er_np(15, 0.5)
        x = np.random.rand(15, 1)
        y = np.random.rand(15, 1)
        # Accept beartype's parameter violation or a TypeError raised by
        # networkplot's own validation. Only the first call below was ever
        # exercised before, so the exact exception type of the rest is
        # unverified.
        type_errors = (TypeError, beartype.roar.BeartypeCallHintParamViolation)

        with self.assertRaises(type_errors):
            networkplot(adjacency="test", x=x, y=y)

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=["A"], y=["A"])

        with self.assertRaises(type_errors):
            networkplot(adjacency=csr_matrix(X),
                        x="source",
                        y="target",
                        node_data="data")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, node_data="data")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, node_hue=(5, 5))

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, palette=4)

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, node_size=(5, 5))

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, node_sizes=4)

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, node_alpha="test")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, edge_hue=4)

        with self.assertRaises(type_errors):
            networkplot(adjacency=csr_matrix(X),
                        x=x,
                        y=y,
                        edge_alpha="test")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, edge_linewidth="test")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, ax="test")

        with self.assertRaises(type_errors):
            networkplot(adjacency=X, x=x, y=y, legend=4)
コード例 #6
0
def test_diag_aug():
    """MultipleASE with diag_aug=True must embed list and ndarray input identically."""
    np.random.seed(5)
    n_vertices, edge_prob = 100, 0.25

    graphs_list = [er_np(n_vertices, edge_prob) for _ in range(2)]
    graphs_arr = np.array(graphs_list)

    def embed(graphs):
        return MultipleASE(diag_aug=True).fit_transform(graphs)

    # Fit the ndarray input first, as before, so any global RNG draws made
    # while fitting happen in the same order.
    from_array = embed(graphs_arr)
    from_list = embed(graphs_list)

    assert array_equal(from_list, from_array)
コード例 #7
0
def test_diag_aug():
    """OmnibusEmbed with diag_aug=True must embed list and ndarray input identically."""
    np.random.seed(5)
    n_vertices, edge_prob = 100, 0.25

    graphs_list = [er_np(n_vertices, edge_prob) for _ in range(2)]
    graphs_arr = np.array(graphs_list)

    def embed(graphs):
        return OmnibusEmbed(diag_aug=True).fit_transform(graphs)

    # Fit the ndarray input first, as before, so any global RNG draws made
    # while fitting happen in the same order.
    from_array = embed(graphs_arr)
    from_list = embed(graphs_list)

    assert array_equal(from_list, from_array)
コード例 #8
0
def test_heatmap_output():
    """Smoke test: heatmap renders under each transform and with a custom colormap."""
    X = er_np(10, 0.5)
    xticklabels = ["Dimension {}".format(i) for i in range(10)]
    yticklabels = list(xticklabels)

    heatmap(X, transform="log", xticklabels=xticklabels, yticklabels=yticklabels)
    for transform in ("zero-boost", "simple-all", "simple-nonzero", "binarize"):
        heatmap(X, transform=transform)
    heatmap(X, cmap="gist_rainbow")
コード例 #9
0
    def test_gridplot_inputs(self):
        """gridplot must reject a non-list X, mismatched labels, and unknown transforms."""
        graphs = [er_np(10, 0.5)]
        names = ["ER(10, 0.5)"]

        # X must be a collection of graphs, not a string.
        with self.assertRaises(TypeError):
            gridplot(X="input", labels=names)

        # Two labels supplied for a single graph.
        with self.assertRaises(ValueError):
            gridplot(graphs, labels=["a", "b"])

        # Unrecognized transform name.
        with self.assertRaises(ValueError):
            gridplot(graphs, labels=names, transform="bad transform")
コード例 #10
0
ファイル: test_plot_matrix.py プロジェクト: zeou1/graspy
def test_adjplot_output():
    """Smoke test: adjplot renders with and without grouping/sorting options."""
    X = er_np(10, 0.5)
    meta = pd.DataFrame(
        {
            "hemisphere": np.random.randint(2, size=10),
            "region": np.random.randint(2, size=10),
            "cell_size": np.random.randint(10, size=10),
        }
    )

    option_sets = (
        {},
        {"group": "hemisphere"},
        {"group": "hemisphere", "group_order": "size"},
        {"group": "hemisphere", "item_order": "cell_size"},
    )
    for options in option_sets:
        adjplot(X, meta=meta, **options)
コード例 #11
0
def _test_sbm_er_binary(self,
                        method,
                        P,
                        directed=False,
                        sparse=False,
                        *args,
                        **kwargs):
    """Check that ``method`` recovers 2-block SBM structure better than on ER.

    For each of ``num_sims`` simulations, the helper samples a 2-community
    SBM (block probabilities ``P``) and an ER(0.5) graph on the same vertex
    set. Each is embedded with ``method(n_components=2, concat=directed)``
    and the embedding shapes are checked. ``_kmeans_comparison`` scores each
    embedding against the planted labels. The SBM must score higher in more
    simulations than ER does.

    Parameters
    ----------
    self : unittest.TestCase
        Provides ``assertEqual`` / ``assertTrue``.
    method : callable
        Embedding class, called with ``n_components`` and ``concat`` keywords.
    P : array-like
        Block probability matrix passed to ``sbm``.
    directed : bool, default False
        Sample directed graphs. Also passed as ``concat``, which doubles the
        expected embedding width.
    sparse : bool, default False
        Convert both graphs to ``csr_matrix`` before embedding.
    *args, **kwargs
        Accepted for call compatibility; not used.
    """
    np.random.seed(8888)

    num_sims = 50
    communities = 2
    verts_per_community = [100, 100]
    verts = sum(verts_per_community)

    # Planted labels (0 for the first community, 1 for the second) are the
    # same in every simulation, so build them once.
    # NOTE(review): assumes _kmeans_comparison does not modify the labels.
    labels_sbm = np.repeat(np.arange(communities),
                           verts_per_community).astype(np.int8)
    labels_er = labels_sbm.copy()

    # With concat=True (directed) the embedding width doubles.
    n_cols = 2 * communities if directed else communities
    expected_shape = (verts, n_cols)

    sbm_wins = 0
    er_wins = 0
    for _ in range(num_sims):
        sbm_sample = sbm(verts_per_community, P, directed=directed)
        er = er_np(verts, 0.5, directed=directed)
        if sparse:
            sbm_sample = csr_matrix(sbm_sample)
            er = csr_matrix(er)
        embed_sbm = method(n_components=2, concat=directed)
        embed_er = method(n_components=2, concat=directed)

        X_sbm = embed_sbm.fit_transform(sbm_sample)
        X_er = embed_er.fit_transform(er)

        self.assertEqual(X_sbm.shape, expected_shape)
        self.assertEqual(X_er.shape, expected_shape)

        aris = _kmeans_comparison((X_sbm, X_er), (labels_sbm, labels_er),
                                  communities)
        sbm_wins += aris[0] > aris[1]
        er_wins += aris[0] < aris[1]

    self.assertTrue(sbm_wins > er_wins)
コード例 #12
0
ファイル: test_plot_matrix.py プロジェクト: zeou1/graspy
def test_adjplot_inputs():
    """adjplot must raise TypeError/ValueError for malformed arguments."""
    X = er_np(100, 0.5)
    meta = pd.DataFrame(
        {
            "hemisphere": np.random.randint(2, size=100),
            "region": np.random.randint(2, size=100),
            "cell_size": np.random.randint(10, size=100),
        }
    )

    # Data must be a 2-d array, not a string or a 3-d array.
    with pytest.raises(TypeError):
        adjplot(data="input", meta=meta)
    with pytest.raises(ValueError):
        adjplot(data=np.zeros((2, 2, 2)), meta=meta)

    # Metadata with a single row does not match the 100-node graph.
    with pytest.raises(ValueError):
        bad_meta = pd.DataFrame(
            {
                "hemisphere": np.random.randint(2, size=1),
                "region": np.random.randint(2, size=1),
                "cell_size": np.random.randint(10, size=1),
            }
        )
        adjplot(X, meta=bad_meta)

    # Unknown plot type.
    with pytest.raises(ValueError):
        adjplot(X, plot_type="bad plottype")

    # Sorting keywords: wrong type first, then names that are not meta columns.
    sorting_keywords = ("group", "group_order", "item_order", "color")
    for keyword in sorting_keywords:
        with pytest.raises(TypeError):
            adjplot(X, meta=meta, **{keyword: 123})
    for keyword in sorting_keywords:
        with pytest.raises(ValueError):
            adjplot(X, meta=meta, **{keyword: "bad value"})
コード例 #13
0
    def test_networkplot_outputs_str(self):
        """Smoke test: networkplot with string-indexed node_data and column-name coordinates."""
        X = er_np(15, 0.7)
        node_df = pd.DataFrame(index=["node {}".format(i) for i in range(15)])
        # Coordinates are drawn first, then hue, then sizes (same draw order).
        for column in ("source", "target"):
            node_df.loc[:, column] = np.random.rand(15, 1)
        node_df.loc[:, "hue"] = np.random.randint(2, size=15)

        style_kwargs = dict(
            node_hue="hue",
            palette={0: (0.8, 0.4, 0.2), 1: (0, 0.9, 0.4)},
            node_size=np.random.rand(15),
            node_sizes=(10, 200),
            node_alpha=0.5,
            edge_alpha=0.4,
            edge_linewidth=0.6,
        )
        networkplot(
            adjacency=X,
            x="source",
            y="target",
            node_data=node_df,
            **style_kwargs,
        )
コード例 #14
0
    def test_networkplot_outputs_int(self):
        """Smoke test: networkplot with dense/sparse adjacency, array or column coordinates, and a given Axes."""
        X = er_np(15, 0.5)
        xarray = np.random.rand(15, 1)
        yarray = np.random.rand(15, 1)
        node_df = pd.DataFrame(index=range(X.shape[0]))
        node_df.loc[:, "source"] = xarray
        node_df.loc[:, "target"] = yarray
        hue = np.random.randint(2, size=15)
        size = np.random.rand(15)

        # Array coordinates, dense then sparse adjacency.
        for adjacency in (X, csr_matrix(X)):
            networkplot(adjacency=adjacency, x=xarray, y=yarray)
        # Column-name coordinates looked up in node_df, dense then sparse.
        for adjacency in (X, csr_matrix(X)):
            networkplot(adjacency=adjacency,
                        x="source",
                        y="target",
                        node_data=node_df)

        ax = plt.figure().add_subplot(211)
        networkplot(
            adjacency=X,
            x=xarray,
            y=yarray,
            node_hue=hue,
            palette={0: (0.8, 0.4, 0.2), 1: (0, 0.9, 0.4)},
            node_size=size,
            node_sizes=(10, 200),
            node_alpha=0.5,
            edge_alpha=0.4,
            edge_linewidth=0.6,
            ax=ax,
        )
コード例 #15
0
ファイル: test_base_embed.py プロジェクト: zeou1/graspy
 def setup_class(cls):
     """Store a shared ER graph (n=20, p=0.5) on the class.

     The graph is built by er_np with directed=True and loops=False.
     """
     n_vertices, edge_prob = 20, 0.5
     cls.n = n_vertices
     cls.p = edge_prob
     cls.A = er_np(n_vertices, edge_prob, directed=True, loops=False)
コード例 #16
0
def test_common_inputs():
    """Shared plotting arguments must be validated by heatmap, gridplot and pairplot.

    Each check passes an invalid value for one common argument (figsize,
    height, title, context, font_scale, tick labels, title_pad,
    hier_label_fontsize) and expects TypeError or ValueError.
    """
    X = er_np(100, 0.5)
    grid_labels = ["Test1"]

    # test figsize
    figsize = "bad figsize"
    with pytest.raises(TypeError):
        heatmap(X, figsize=figsize)

    # test height
    height = "1"
    with pytest.raises(TypeError):
        gridplot([X], grid_labels, height=height)
    with pytest.raises(TypeError):
        pairplot(X, height=height)

    # test title
    title = 1
    with pytest.raises(TypeError):
        heatmap(X, title=title)
    with pytest.raises(TypeError):
        gridplot([X], grid_labels, title=title)
    with pytest.raises(TypeError):
        pairplot(X, title=title)

    # test context: wrong type
    context = 123
    with pytest.raises(TypeError):
        heatmap(X, context=context)
    with pytest.raises(TypeError):
        gridplot([X], grid_labels, context=context)
    with pytest.raises(TypeError):
        pairplot(X, context=context)

    # test context: unknown context name
    context = "journal"
    with pytest.raises(ValueError):
        heatmap(X, context=context)
    with pytest.raises(ValueError):
        gridplot([X], grid_labels, context=context)
    with pytest.raises(ValueError):
        pairplot(X, context=context)

    # test font scales
    font_scales = ["1", []]
    for font_scale in font_scales:
        with pytest.raises(TypeError):
            heatmap(X, font_scale=font_scale)
        with pytest.raises(TypeError):
            gridplot([X], grid_labels, font_scale=font_scale)
        # The keyword must be `font_scale`. The former misspelling
        # `cont_scale` would be rejected as an unexpected keyword, so the
        # check passed without exercising pairplot's font_scale validation.
        with pytest.raises(TypeError):
            pairplot(X, font_scale=font_scale)

    # ticklabels: wrong type (a string instead of a list)
    xticklabels = "labels"
    yticklabels = "labels"
    with pytest.raises(TypeError):
        heatmap(X, xticklabels=xticklabels, yticklabels=yticklabels)

    # ticklabels: wrong length (5 labels for a 100-vertex graph)
    xticklabels = ["{}".format(i) for i in range(5)]
    yticklabels = ["{}".format(i) for i in range(5)]
    with pytest.raises(ValueError):
        heatmap(X, xticklabels=xticklabels, yticklabels=yticklabels)

    with pytest.raises(TypeError):
        heatmap(X, title_pad="f")

    with pytest.raises(TypeError):
        gridplot([X], title_pad="f")

    with pytest.raises(TypeError):
        heatmap(X, hier_label_fontsize="f")

    with pytest.raises(TypeError):
        gridplot([X], hier_label_fontsize="f")
コード例 #17
0
    def test_common_inputs(self):
        """Shared plotting arguments must be validated by every plotter.

        heatmap, gridplot and pairplot are expected to raise TypeError /
        ValueError themselves. networkplot's wrongly-typed arguments are
        expected to be rejected with beartype's
        ``BeartypeCallHintParamViolation``.
        """
        X = er_np(100, 0.5)
        x = np.random.rand(100, 1)
        y = np.random.rand(100, 1)
        grid_labels = ["Test1"]
        # Use one assertRaises per networkplot call. An assertRaises(TypeError)
        # nested inside it would be dead: the outer assertion can only pass
        # when the inner one did not catch anything.
        beartype_violation = beartype.roar.BeartypeCallHintParamViolation

        # test figsize
        figsize = "bad figsize"
        with self.assertRaises(TypeError):
            heatmap(X, figsize=figsize)
        with self.assertRaises(beartype_violation):
            networkplot(adjacency=X, x=x, y=y, figsize=figsize)

        # test height
        height = "1"
        with self.assertRaises(TypeError):
            gridplot([X], grid_labels, height=height)
        with self.assertRaises(TypeError):
            pairplot(X, height=height)

        # test title
        title = 1
        with self.assertRaises(TypeError):
            heatmap(X, title=title)
        with self.assertRaises(TypeError):
            gridplot([X], grid_labels, title=title)
        with self.assertRaises(TypeError):
            pairplot(X, title=title)
        with self.assertRaises(beartype_violation):
            networkplot(adjacency=X, x=x, y=y, title=title)

        # test context: wrong type
        context = 123
        with self.assertRaises(TypeError):
            heatmap(X, context=context)
        with self.assertRaises(TypeError):
            gridplot([X], grid_labels, context=context)
        with self.assertRaises(TypeError):
            pairplot(X, context=context)
        with self.assertRaises(beartype_violation):
            networkplot(adjacency=X, x=x, y=y, context=context)

        # test context: unknown context name
        context = "journal"
        with self.assertRaises(ValueError):
            heatmap(X, context=context)
        with self.assertRaises(ValueError):
            gridplot([X], grid_labels, context=context)
        with self.assertRaises(ValueError):
            pairplot(X, context=context)
        with self.assertRaises(ValueError):
            networkplot(adjacency=X, x=x, y=y, context=context)

        # test font scales
        font_scales = ["1", []]
        for font_scale in font_scales:
            with self.assertRaises(TypeError):
                heatmap(X, font_scale=font_scale)
            with self.assertRaises(TypeError):
                gridplot([X], grid_labels, font_scale=font_scale)
            with self.assertRaises(TypeError):
                pairplot(X, font_scale=font_scale)
            with self.assertRaises(beartype_violation):
                networkplot(adjacency=X, x=x, y=y, font_scale=font_scale)

        # ticklabels: wrong type (a string instead of a list)
        xticklabels = "labels"
        yticklabels = "labels"
        with self.assertRaises(TypeError):
            heatmap(X, xticklabels=xticklabels, yticklabels=yticklabels)

        # ticklabels: wrong length (5 labels for a 100-vertex graph)
        xticklabels = ["{}".format(i) for i in range(5)]
        yticklabels = ["{}".format(i) for i in range(5)]
        with self.assertRaises(ValueError):
            heatmap(X, xticklabels=xticklabels, yticklabels=yticklabels)

        with self.assertRaises(TypeError):
            heatmap(X, title_pad="f")

        with self.assertRaises(TypeError):
            gridplot([X], title_pad="f")

        with self.assertRaises(TypeError):
            heatmap(X, hier_label_fontsize="f")

        with self.assertRaises(TypeError):
            gridplot([X], hier_label_fontsize="f")