예제 #1
0
    def test_plot_graphs(self):
        r"""
        Smoke-test plotting for every graph class that has coordinates.

        Each graph is drawn four times: without and with a signal, on the
        pyqtgraph backend and on the matplotlib backend.
        """

        # Graph classes without an embedding, i.e. with no coordinates.
        unembedded = {
            'Graph',
            'BarabasiAlbert',
            'ErdosRenyi',
            'FullConnected',
            'RandomRegular',
            'StochasticBlockModel',
        }

        # Graph classes whose coordinates are neither 2D nor 3D.
        wrong_dim = {'ImgPatches'}

        # Additional keyword-argument variants, instantiated in this order
        # right after the default instance of the same class.
        extra_kwargs = {
            'TwoMoons': [{'moontype': 'standard'},
                         {'moontype': 'synthesized'}],
            'Cube': [{'nb_dim': 2}, {'nb_dim': 3}],
            'DavidSensorNet': [{'N': 64}, {'N': 500}, {'N': 128}],
        }

        to_plot = []
        for name in set(graphs.__all__) - unembedded - wrong_dim:
            graph_cls = getattr(graphs, name)

            # Some classes cannot be built without arguments.
            if name == 'NNGraph':
                instance = graph_cls(np.arange(90).reshape(30, 3))
            elif name in ['ImgPatches', 'Grid2dImgPatches']:
                instance = graph_cls(img=self._img, patch_shape=(3, 3))
            elif name == 'LineGraph':
                instance = graph_cls(graphs.Sensor(20, seed=42))
            else:
                instance = graph_cls()
            to_plot.append(instance)

            for kwargs in extra_kwargs.get(name, ()):
                to_plot.append(graph_cls(**kwargs))

        for graph in to_plot:
            self.assertTrue(hasattr(graph, 'coords'))
            self.assertEqual(graph.N, graph.coords.shape[0])

            signal = np.arange(graph.N) + 0.3

            # Unsigned plots first, then signed; pyqtgraph before matplotlib.
            for args in ((), (signal,)):
                for backend in ('pyqtgraph', 'matplotlib'):
                    graph.plot(*args, backend=backend)
            plotting.close_all()
예제 #2
0
    def test_plot_graphs(self):
        r"""
        Plot all graphs which have coordinates.

        Every graph is plotted with and without a signal, using both the
        pyqtgraph and the matplotlib backends.
        """

        # Graphs which are not embedded, i.e., have no coordinates.
        COORDS_NO = {
            'Graph',
            'BarabasiAlbert',
            'ErdosRenyi',
            'FullConnected',
            'RandomRegular',
            'StochasticBlockModel',
            }

        # Coordinates are not in 2D or 3D.
        COORDS_WRONG_DIM = {'ImgPatches'}

        Gs = []
        for classname in set(graphs.__all__) - COORDS_NO - COORDS_WRONG_DIM:
            Graph = getattr(graphs, classname)

            # Classes which require parameters.
            if classname == 'NNGraph':
                Xin = np.arange(90).reshape(30, 3)
                Gs.append(Graph(Xin))
            elif classname in ['ImgPatches', 'Grid2dImgPatches']:
                Gs.append(Graph(img=self._img, patch_shape=(3, 3)))
            elif classname == 'LineGraph':
                # LineGraph is constructed from a source graph (presumably it
                # has no zero-argument form). This branch is inert when
                # `graphs` does not export LineGraph.
                Gs.append(Graph(graphs.Sensor(20, seed=42)))
            else:
                Gs.append(Graph())

            # Add more test cases.
            if classname == 'TwoMoons':
                Gs.append(Graph(moontype='standard'))
                Gs.append(Graph(moontype='synthesized'))
            elif classname == 'Cube':
                Gs.append(Graph(nb_dim=2))
                Gs.append(Graph(nb_dim=3))
            elif classname == 'DavidSensorNet':
                Gs.append(Graph(N=64))
                Gs.append(Graph(N=500))
                Gs.append(Graph(N=128))

        for G in Gs:
            self.assertTrue(hasattr(G, 'coords'))
            self.assertEqual(G.N, G.coords.shape[0])

            signal = np.arange(G.N) + 0.3

            G.plot(backend='pyqtgraph')
            G.plot(backend='matplotlib')
            G.plot(signal, backend='pyqtgraph')
            G.plot(signal, backend='matplotlib')
            plotting.close_all()
예제 #3
0
 def test_all_filters(self):
     """Smoke-test plot() on one instance of every public filter class."""
     # Names not starting with an upper-case letter are submodules or
     # private helpers rather than Filter classes.
     public_names = (name for name in dir(filters) if name[0].isupper())
     # These classes additionally receive a filters.Heat instance as their
     # second constructor argument.
     takes_kernel = ('Filter', 'Modulation', 'Gabor')
     for name in public_names:
         filter_cls = getattr(filters, name)
         args = (self._graph,)
         if name in takes_kernel:
             args += (filters.Heat(self._graph),)
         filter_cls(*args).plot()
         plotting.close_all()
예제 #4
0
 def tearDown(self):
     """Release plotting resources after each test via plotting.close_all().

     unittest invokes this hook on the test instance, so its first
     parameter is ``self`` (it was previously misnamed ``cls``).
     """
     plotting.close_all()
예제 #5
0
 def test_show_close(self):
     """Plot a sensor graph, show it without blocking, then close it."""
     sensor_graph = graphs.Sensor()
     sensor_graph.plot()
     # Must be non-blocking, otherwise the test run halts here.
     plotting.show(block=False)
     # Exercise both close() and close_all(), in that order.
     plotting.close()
     plotting.close_all()
예제 #6
0
 def test_show_close(self):
     """Smoke-test plotting.show(block=False), close() and close_all()."""
     G = graphs.Sensor()
     G.plot()
     plotting.show(block=False)  # a blocking show would halt the test
     for closer in (plotting.close, plotting.close_all):
         closer()