コード例 #1
0
 def test_show_examples_supervised(self, _):
     """Smoke-test `show_examples` on a dataset loaded with `as_supervised=True`.

     Loads 20 mocked imagenet2012 training examples and plots them; the test
     passes if plotting does not raise.

     Args:
         _: Unused; presumably the mock injected by a `mock.patch` decorator
             outside this view.
     """
     load_kwargs = dict(split='train', with_info=True, as_supervised=True)
     with testing.mock_data(num_examples=20):
         ds, ds_info = load.load('imagenet2012', **load_kwargs)
     visualization.show_examples(ds, ds_info)
コード例 #2
0
def _as_df(ds_name: str, **kwargs) -> pandas.DataFrame:
    """Load 3 mocked training examples of `ds_name` as a `pandas.DataFrame`.

    Args:
        ds_name: Dataset name passed to `load.load`.
        **kwargs: Extra keyword arguments forwarded to `load.load`. They must
            not include `split` or `with_info`, which are fixed here; passing
            either raises `TypeError` (duplicate keyword argument).

    Returns:
        The frame built by `as_dataframe.as_dataframe` from the loaded
        dataset and its info.
    """
    with testing.mock_data(num_examples=3):
        dataset, info = load.load(ds_name, split='train', with_info=True,
                                  **kwargs)
    return as_dataframe.as_dataframe(dataset, info)
コード例 #3
0
    def test_show_examples_graph_with_colors_and_labels(self, _):
        """Smoke-test `show_examples` on ogbg_molpcba with custom callbacks.

        Loads 20 mocked training examples and plots them with three callbacks:
        node colors and node labels come from atomic numbers, and edge colors
        come from bond types. The test passes if the call does not raise.

        Args:
            _: Unused; presumably the mock injected by a `mock.patch`
                decorator outside this view.
        """
        with testing.mock_data(num_examples=20):
            ds, ds_info = load.load('ogbg_molpcba',
                                    split='train',
                                    with_info=True)

        # Dictionaries used to map nodes and edges to colors.
        # Atomic number -> element symbol (also used as the node label).
        atomic_numbers_to_elements = {
            6: 'C',
            7: 'N',
            8: 'O',
            9: 'F',
            14: 'Si',
            15: 'P',
            16: 'S',
            17: 'Cl',
            35: 'Br',
        }
        # Element symbol -> one distinct color string per element
        # ('C0', 'C1', ...; presumably matplotlib color-cycle specs).
        elements_to_colors = {
            element: f'C{index}'
            for index, element in enumerate(
                atomic_numbers_to_elements.values())
        }
        # Bond type 0..3 -> color string.
        bond_types_to_colors = {num: f'C{num}' for num in range(4)}

        # NOTE(review): the callbacks below index these dicts directly. Any
        # atomic number or bond type in the (mocked) data that is not a key
        # raises KeyError.

        # Node colors are keyed by element, i.e. by atomic number.
        # Assumes column 0 of `node_feat` stores the atomic number minus one
        # (hence the `1 +`) -- TODO confirm against the feature spec.
        def node_color_fn(graph):
            """Map each node index to the color of its element."""
            atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()
            return {
                index:
                elements_to_colors[atomic_numbers_to_elements[atomic_number]]
                for index, atomic_number in enumerate(atomic_numbers)
            }

        # Node labels are element names.
        def node_label_fn(graph):
            """Map each node index to its element symbol."""
            atomic_numbers = 1 + graph['node_feat'][:, 0].numpy()
            return {
                index: atomic_numbers_to_elements[atomic_number]
                for index, atomic_number in enumerate(atomic_numbers)
            }

        # Edge colors are bond types.
        # Assumes each row of `edge_index` holds one edge's endpoint indices
        # and column 0 of `edge_feat` is the bond type -- TODO confirm.
        def edge_color_fn(graph):
            """Map each edge (a tuple built from an `edge_index` row) to a color."""
            bonds = graph['edge_index'].numpy()
            bond_types = graph['edge_feat'][:, 0].numpy()
            return {
                tuple(bond): bond_types_to_colors[bond_type]
                for bond, bond_type in zip(bonds, bond_types)
            }

        visualization.show_examples(ds,
                                    ds_info,
                                    node_color_fn=node_color_fn,
                                    node_label_fn=node_label_fn,
                                    edge_color_fn=edge_color_fn)
コード例 #4
0
    def test_show_examples(self, mock_fig):
        """Smoke-test `show_examples` on two mocked datasets.

        Within a single `mock_data(num_examples=20)` context, loads the
        imagenet2012 'train' split and then the crema_d 'validation' split.
        Each is plotted right after it is loaded. The test passes if nothing
        raises.

        Args:
            mock_fig: Unused; presumably the mock injected by a `mock.patch`
                decorator outside this view.
        """
        # NOTE(review): this passes `(ds_info, ds)` to `show_examples`.
        # Confirm that order against the installed `show_examples` signature.
        datasets_and_splits = (
            ('imagenet2012', 'train'),
            ('crema_d', 'validation'),
        )
        with testing.mock_data(num_examples=20):
            for dataset_name, split_name in datasets_and_splits:
                ds, ds_info = registered.load(dataset_name,
                                              split=split_name,
                                              with_info=True)
                visualization.show_examples(ds_info, ds)
コード例 #5
0
 def test_show_examples(self):
   """Smoke-test `show_statistics` on the mocked imagenet2012 builder info.

   NOTE(review): despite its name, this test exercises
   `visualization.show_statistics`, not `show_examples`; consider renaming.
   """
   with testing.mock_data():
     imagenet_info = registered.builder('imagenet2012').info
     visualization.show_statistics(imagenet_info)
コード例 #6
0
 def test_show_examples_missing_sample(self, _):
     """Smoke-test `show_examples` on a dataset of only 3 examples.

     Only 3 mocked examples are generated, and `take(3)` also caps the
     dataset at 3. Presumably this is fewer than the plot grid holds, which
     checks that missing samples are handled; the test passes if plotting
     does not raise.

     Args:
         _: Unused; presumably the mock injected by a `mock.patch` decorator
             outside this view.
     """
     with testing.mock_data(num_examples=3):
         ds, ds_info = load.load(
             'imagenet2012', split='train', with_info=True)
     first_three = ds.take(3)
     visualization.show_examples(first_three, ds_info)
コード例 #7
0
 def test_show_examples_graph(self, _):
     """Smoke-test `show_examples` on ogbg_molpcba with default options.

     Loads 20 mocked training examples and plots them without custom
     node/edge callbacks; the test passes if plotting does not raise.

     Args:
         _: Unused; presumably the mock injected by a `mock.patch` decorator
             outside this view.
     """
     dataset_name = 'ogbg_molpcba'
     with testing.mock_data(num_examples=20):
         ds, ds_info = load.load(dataset_name, split='train', with_info=True)
     visualization.show_examples(ds, ds_info)