def show(
        self,
        ds_info,
        ds,
        rows=2,
        cols=2,
        plot_scale=3.,
        audio_key=None,
    ):
        """Display the audio dataset.

        Args:
          ds_info: `tfds.core.DatasetInfo` object of the dataset to visualize.
          ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize.
          rows: `int`, number of rows of the display grid. Default is 2.
          cols: `int`, number of columns of the display grid. Default is 2.
          plot_scale: `float`, controls the plot size of the images. Keep this
            value around 3 to get a good plot. High and low values may cause
            the labels to get overlapped.
          audio_key: `string`, name of the feature that contains the audio. If
            not set, the first `Audio` feature found in `ds_info.features` is
            used.

        Raises:
          ValueError: If `audio_key` is not set and no `Audio` feature can be
            detected in the dataset.
        """
        if not audio_key:
            # Auto-infer the audio key from the dataset features.
            audio_keys = visualizer.extract_keys(ds_info.features,
                                                 features_lib.Audio)
            if not audio_keys:
                raise ValueError(
                    'No audio feature detected in the dataset. '
                    'Use `audio_key` argument to override.')
            audio_key = audio_keys[0]
        # Use the sample rate declared on the feature; fall back to 16000 Hz
        # when the feature does not declare one.
        samplerate = ds_info.features[audio_key].sample_rate
        if not samplerate:
            samplerate = 16000
        _make_audio_grid(ds, audio_key, samplerate, rows, cols, plot_scale)
Exemplo n.º 2
0
    def show(
        self,
        ds_info,
        ds,
        rows=3,
        cols=3,
        plot_scale=3.,
        image_key=None,
    ):
        """Display the dataset.

        Args:
          ds_info: `tfds.core.DatasetInfo` object of the dataset to visualize.
          ds: `tf.data.Dataset`. The tf.data.Dataset object to visualize.
            Examples should not be batched. Examples will be consumed in order
            until (rows * cols) are read or the dataset is consumed.
          rows: `int`, number of rows of the display grid.
          cols: `int`, number of columns of the display grid.
          plot_scale: `float`, controls the plot size of the images. Keep this
            value around 3 to get a good plot. High and low values may cause
            the labels to get overlapped.
          image_key: `string`, name of the feature that contains the image. If
            not set, the system will try to auto-detect it.

        Returns:
          fig: The pyplot figure.

        Raises:
          ValueError: If `image_key` is not set and the dataset has no image
            feature or more than one. The per-cell callback also raises
            ValueError when an example is not a `dict`.
        """
        # Extract the image key
        if not image_key:
            image_keys = visualizer.extract_keys(ds_info.features,
                                                 features_lib.Image)
            if not image_keys:
                raise ValueError(
                    'No image feature detected in the dataset. '
                    'Use `image_key` argument to override.')
            if len(image_keys) > 1:
                raise ValueError(
                    'Multiple image features detected in the dataset. '
                    'Use `image_key` argument to override. Images detected: {}'
                    .format(image_keys))
            image_key = image_keys[0]

        # Optionally extract the label key; only used when exactly one
        # ClassLabel feature exists, otherwise cells are drawn without labels.
        label_keys = visualizer.extract_keys(ds_info.features,
                                             features_lib.ClassLabel)
        label_key = label_keys[0] if len(label_keys) == 1 else None
        if not label_key:
            logging.info('Was not able to auto-infer label.')

        # Single image display
        def make_cell_fn(ax, ex):
            plt = lazy_imports_lib.lazy_imports.matplotlib.pyplot

            if not isinstance(ex, dict):
                raise ValueError(
                    '{} requires examples as `dict`, with the same '
                    'structure as `ds_info.features`. It is currently not compatible '
                    'with `as_supervised=True`. Received: {}'.format(
                        type(self).__name__, type(ex)))

            _add_image(ax, ex[image_key])
            if label_key:
                label = ex[label_key]
                label_str = ds_info.features[label_key].int2str(label)
                plt.xlabel('{} ({})'.format(label_str, label))

        # Returns the grid.
        fig = _make_grid(make_cell_fn, ds, rows, cols, plot_scale)
        return fig
Exemplo n.º 3
0
 def match(self, ds_info):
     """See base class.

     Matches any dataset that declares at least one `Image` feature.
     """
     found_image_keys = visualizer.extract_keys(
         ds_info.features, features_lib.Image)
     return len(found_image_keys) > 0
 def match(self, ds_info):
     """See base class.

     Matches any dataset that declares at least one `Audio` feature.
     """
     found_audio_keys = visualizer.extract_keys(
         ds_info.features, features_lib.Audio)
     return len(found_audio_keys) != 0