Example #1
 def test_slice_request_sanity_check_raises_error_on_start_geq_stop(self):
     assert_raises(ValueError,
                   Subset([0, 1, 2], 8)._slice_request_sanity_check,
                   slice(1, 1), 3)
     assert_raises(ValueError,
                   Subset([0, 1, 2], 8)._slice_request_sanity_check,
                   slice(2, 1), 3)
Example #2
 def test_slice_request_sanity_check_raises_error_on_start_geq_num_ex(self):
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(1, None), 1)
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(2, None), 1)
Example #3
 def test_add_contiguous_single_step_slice_slice(self):
     assert_equal((Subset(slice(0, 4, 1), 10) +
                   Subset(slice(4, 7, 1), 10)).list_or_slice,
                  slice(0, 7, 1))
     assert_equal((Subset(slice(4, 7, 1), 10) +
                   Subset(slice(0, 4, 1), 10)).list_or_slice,
                  slice(0, 7, 1))
Example #4
 def test_add_overlapping_single_step_slice_slice(self):
     assert_equal((Subset(slice(0, 6, 1), 10) +
                   Subset(slice(4, 7, 1), 10)).list_or_slice,
                  slice(0, 7, 1))
     assert_equal((Subset(slice(4, 7, 1), 10) +
                   Subset(slice(0, 6, 1), 10)).list_or_slice,
                  slice(0, 7, 1))
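Examples #1 through #4 pin down the addition semantics of Subset: two subsets over the same number of original examples can be added, and contiguous or overlapping single-step slices merge back into a single slice. When no such merge is possible, addition falls back to a sorted list of indices (see Example #17). A minimal sketch of both outcomes, assuming Subset is the class from fuel.utils:

    from fuel.utils import Subset

    # Contiguous single-step slices merge into one slice.
    merged = Subset(slice(0, 4, 1), 10) + Subset(slice(4, 7, 1), 10)
    print(merged.list_or_slice)  # slice(0, 7, 1)

    # Slices separated by a gap fall back to a list of indices.
    gapped = Subset(slice(0, 4), 20) + Subset(slice(12, 16), 20)
    print(gapped.list_or_slice)  # [0, 1, 2, 3, 12, 13, 14, 15]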
Example #5
    def get_subsets(h5file, splits, sources):
        """Returns the subsets for a given splits/sources combination.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.
        splits : :class:`tuple` of :class:`str`
            Split names.
        sources : :class:`tuple` of :class:`str`
            Which sources should be considered.

        Returns
        -------
        :class:`list` of :class:`fuel.utils.Subset`
            The subsets, one per source in ``sources``, associated with
            the splits/sources combination.

        """
        subsets = [Subset.empty_subset(len(h5file[source_name]))
                   for source_name in sources]
        for split in splits:
            for i, source in enumerate(sources):
                row, = [r for r in h5file.attrs['split'] if
                        (r['split'].decode('utf8') == split and
                         r['source'].decode('utf8') == source)]
                if row['indices']:
                    subsets[i] += Subset(
                        h5file[row['indices']], len(h5file[source]))
                else:
                    subsets[i] += Subset(
                        slice(row['start'], row['stop']), len(h5file[source]))

        return subsets
Example #6
 def test_raises_value_error_on_indexing_empty_subset(self):
     assert_raises(ValueError,
                   Subset([], 2).index_within_subset, [1, 2], [1])
     assert_raises(ValueError,
                   Subset([], 2).index_within_subset, [1, 2], slice(1, 2))
     assert_raises(ValueError,
                   Subset(slice(0, 0), 2).index_within_subset, [1, 2], [1])
     assert_raises(ValueError,
                   Subset(slice(0, 0), 2).index_within_subset, [1, 2],
                   slice(1, 2))
Example #7
 def test_slice_request_sanity_check_raises_error_on_negative_attr(self):
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(-1, None, None), 1)
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(None, -1, None), 1)
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(None, None, -1), 1)
Example #8
    def load(self):
        # If the dataset is unpickled, it makes no sense to have an external
        # file handle. However, since `load` is also called during the lifetime
        # of a dataset (e.g. if load_in_memory = True), we don't want to
        # accidentally overwrite the reference to a potential external file
        # handle, hence this check.
        if not hasattr(self, '_external_file_handle'):
            self.external_file_handle = None

        self._out_of_memory_open()
        handle = self._file_handle

        # Infer subsets based on `which_sets`
        subsets = self.get_subsets(handle, self.which_sets, self.sources)
        # Sanity check to make sure that all sources have equal length
        if any(subset.num_examples != subsets[0].num_examples
               for subset in subsets):
            raise ValueError("sources have different lengths")
        # Produce the final subsets by taking the `subset` constructor argument
        # into account.
        self.subsets = [
            Subset.subset_of(subset, self.user_given_subset)
            for subset in subsets
        ]

        # Load data sources and source shapes (if requested)
        if self.load_in_memory:
            data_sources = []
            source_shapes = []
            for source_name, subset in zip(self.sources, self.subsets):
                data_sources.append(
                    subset.index_within_subset(handle[source_name],
                                               slice(None)))
                if source_name in self.vlen_sources:
                    shapes = subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'], slice(None))
                else:
                    shapes = None
                source_shapes.append(shapes)
            self.data_sources = tuple(data_sources)
            self.source_shapes = tuple(source_shapes)
            # This exists only for request sanity checking purposes.
            self.in_memory_subset = Subset(slice(None),
                                           len(self.data_sources[0]))
        else:
            self.data_sources = None
            self.source_shapes = None
            self.in_memory_subset = None

        self._out_of_memory_close()
Example #9
    def load(self):
        # If the dataset is unpickled, it makes no sense to have an external
        # file handle. However, since `load` is also called during the lifetime
        # of a dataset (e.g. if load_in_memory = True), we don't want to
        # accidentally overwrite the reference to a potential external file
        # handle, hence this check.
        if not hasattr(self, '_external_file_handle'):
            self.external_file_handle = None

        self._out_of_memory_open()
        handle = self._file_handle

        # Infer subsets based on `which_sets`
        subsets = self.get_subsets(handle, self.which_sets, self.sources)
        # Sanity check to make sure that all sources have equal length
        if any(subset.num_examples != subsets[0].num_examples for subset in
                subsets):
            raise ValueError("sources have different lengths")
        # Produce the final subsets by taking the `subset` constructor argument
        # into account.
        self.subsets = [Subset.subset_of(subset, self.user_given_subset)
                        for subset in subsets]

        # Load data sources and source shapes (if requested)
        if self.load_in_memory:
            data_sources = []
            source_shapes = []
            for source_name, subset in zip(self.sources, self.subsets):
                data_sources.append(
                    subset.index_within_subset(
                        handle[source_name], slice(None)))
                if source_name in self.vlen_sources:
                    shapes = subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'],
                        slice(None))
                else:
                    shapes = None
                source_shapes.append(shapes)
            self.data_sources = tuple(data_sources)
            self.source_shapes = tuple(source_shapes)
            # This exists only for request sanity checking purposes.
            self.in_memory_subset = Subset(
                slice(None), len(self.data_sources[0]))
        else:
            self.data_sources = None
            self.source_shapes = None
            self.in_memory_subset = None

        self._out_of_memory_close()
Example #10
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data',)
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            self.indexables = [indexables[source][start:stop]
                               for source in self.sources]
            try:
                if not all(len(indexable) == len(self.indexables[0])
                           for indexable in self.indexables):
                    raise ValueError("sources have different lengths")
            # In case of sparse input
            except TypeError:
                if not all(indexable.shape[0] == self.indexables[0].shape[0]
                           for indexable in self.indexables):
                    raise ValueError("sources have different lengths")

        else:
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)
Example #11
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data',)
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            self.indexables = [indexables[source][start:stop]
                               for source in self.sources]
            if not all(len(indexable) == len(self.indexables[0])
                       for indexable in self.indexables):
                raise ValueError("sources have different lengths")
        else:
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)
Example #12
 def test_add_raises_value_error_when_incompatible(self):
     # Adding two Subset instances should only work when they have the same
     # number of original examples.
     assert_raises(ValueError, operator.add, Subset([1, 3], 10),
                   Subset([2, 4], 11))
Example #13
 def test_slice_subset_slice_request(self):
     assert_equal(Subset(slice(1, 14), 16)[slice(1, 4, 2)], slice(2, 5, 2))
Example #14
 def test_slice_subset_list_request(self):
     assert_equal(Subset(slice(1, 14), 16)[[3, 2, 4]], [4, 3, 5])
Example #15
 def test_list_subset_slice_request(self):
     assert_equal(Subset([0, 2, 5, 7, 10, 15], 16)[slice(1, 4, 2)], [2, 7])
Example #16
 def test_slice_request_sanity_check_raises_error_on_stop_gt_num_ex(self):
     assert_raises(ValueError,
                   Subset([0], 8)._slice_request_sanity_check,
                   slice(None, 2), 1)
Example #17
 def test_adding_slice_slice_falls_back_to_list(self):
     # If Subset can't find a way to add two slices together, it must
     # return a list-based Subset.
     assert_equal((Subset(slice(0, 4), 20) +
                   Subset(slice(12, 16), 20)).list_or_slice,
                  [0, 1, 2, 3, 12, 13, 14, 15])
Example #18
 def test_safe_sorted_fancy_indexing_gt_1(self):
     indexable = numpy.arange(10)
     assert_equal(Subset.sorted_fancy_indexing(indexable, [0, 5, 2]),
                  [0, 5, 2])
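Example #18 exercises Subset.sorted_fancy_indexing, the helper behind the sort_indices behaviour documented in the H5PYDataset example further below: h5py only accepts fancy indices in ascending order, so the request is sorted, the data is read in sorted order, and the rows are then scattered back into the requested order. A minimal sketch of that trick, written as a hypothetical standalone helper rather than Fuel's exact implementation:

    import numpy

    def sorted_fancy_indexing(indexable, request):
        # argsort gives the positions that would sort the request.
        order = numpy.argsort(request)
        result = numpy.empty(shape=(len(request),) + indexable.shape[1:],
                             dtype=indexable.dtype)
        # Read rows in ascending index order, then scatter them back
        # into the originally requested order.
        result[order] = indexable[numpy.asarray(request)[order], ...]
        return result

    print(sorted_fancy_indexing(numpy.arange(10), [0, 5, 2]))  # [0 5 2]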
Example #19
 def test_slice_num_examples(self):
     assert_equal(Subset(slice(3, 18, 1), 50).num_examples, 15)
Example #20
 def test_list_num_examples(self):
     assert_equal(Subset([0, 3, 8, 13], 15).num_examples, 4)
Example #21
 def test_list_request_sanity_check_raises_error_on_empty_list(self):
     assert_raises(ValueError,
                   Subset([0], 8)._list_request_sanity_check, [], 1)
Example #22
 def test_list_request_sanity_check_raises_error_on_negative_index(self):
     assert_raises(ValueError,
                   Subset([0], 8)._list_request_sanity_check, [-1], 1)
Example #23
 def test_add_list_list(self):
     assert_equal(
         (Subset([0, 3, 2, 8], 10) + Subset([0, 4, 5], 10)).list_or_slice,
         [0, 2, 3, 4, 5, 8])
Example #24
 def test_add_slice_list(self):
     assert_equal(
         (Subset(slice(1, 5), 10) + Subset([0, 3, 2, 8], 10)).list_or_slice,
         [0, 1, 2, 3, 4, 8])
Example #25
 def test_none_slice_request(self):
     assert_equal(Subset([1, 3, 5, 7], 8)[slice(None)], [1, 3, 5, 7])
     assert_equal(Subset(slice(0, 8, 1), 8)[slice(None)], slice(0, 8, 1))
Example #26
class IndexableDataset(Dataset):
    """Creates a dataset from a set of indexable containers.

    Parameters
    ----------
    indexables : :class:`~collections.OrderedDict` or indexable
        The indexable(s) to provide interface to. This means it must
        support the syntax ``indexable[0]``. If an
        :class:`~collections.OrderedDict` is given, its values should be
        indexables providing data, and its keys strings that are used as
        source names. If a single indexable is given, it will be given the
        source ``data``.

    Attributes
    ----------
    indexables : list
        A list of indexable objects.

    Notes
    -----
    If the indexable data is very large, you might want to consider using
    the :func:`.do_not_pickle_attributes` decorator to make sure the data
    doesn't get pickled with the dataset, but gets reloaded/recreated
    instead.

    This dataset also uses the source names to create properties that
    provide easy access to the data.

    """
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data',)
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            self.indexables = [indexables[source][start:stop]
                               for source in self.sources]
            try:
                if not all(len(indexable) == len(self.indexables[0])
                           for indexable in self.indexables):
                    raise ValueError("sources have different lengths")
            # In case of sparse input
            except TypeError:
                if not all(indexable.shape[0] == self.indexables[0].shape[0]
                           for indexable in self.indexables):
                    raise ValueError("sources have different lengths")

        else:
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)

    def __getattr__(self, attr):
        if (attr not in ['sources', 'indexables', '_sources'] and
                attr in self.sources):
            return self.indexables[self.sources.index(attr)]
        raise AttributeError

    # Without explicitly defining a trivial __setstate__ method,
    # the __getattribute__ method would call the __getattr__ method,
    # which would raise an AttributeError. This causes problems
    # when unpickling.
    def __setstate__(self, dict):
        self.__dict__ = dict

    @property
    def num_examples(self):
        try:
            return len(self.indexables[0])
        except TypeError:
            # Sparse matrices don't support len(); fall back to the
            # first dimension of their shape.
            if isinstance(self.indexables[0], scipy.sparse.csr.csr_matrix):
                return self.indexables[0].shape[0]
            raise

    def get_data(self, state=None, request=None):
        if state is not None or request is None:
            raise ValueError
        return tuple(self.subset.index_within_subset(indexable, request)
                     for indexable in self.indexables)
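As the docstring above describes, IndexableDataset wraps one or more indexable containers and serves requests through a Subset. A hedged usage sketch, assuming Fuel's usual layout where the class is importable from fuel.datasets; the source names features and targets are illustrative:

    from collections import OrderedDict
    import numpy
    from fuel.datasets import IndexableDataset

    dataset = IndexableDataset(OrderedDict([
        ('features', numpy.arange(20).reshape(10, 2)),
        ('targets', numpy.arange(10))]))

    # Sources are exposed both through get_data and as attributes.
    features, targets = dataset.get_data(request=[0, 2, 4])
    print(dataset.targets.shape)  # (10,)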
Example #27
 def test_list_subset_list_request(self):
     assert_equal(Subset([0, 2, 5, 7, 10, 15], 16)[[3, 2, 4]], [7, 5, 10])
Example #28
 def test_is_list_property(self):
     assert not Subset(slice(None, None, None), 2).is_list
     assert Subset([0, 1, 3], 4).is_list
Example #29
 def test_list_request_sanity_check_raises_error_on_index_geq_num_ex(self):
     assert_raises(ValueError,
                   Subset([0], 8)._list_request_sanity_check, [1], 1)
     assert_raises(ValueError,
                   Subset([0], 8)._list_request_sanity_check, [2], 1)
Example #30
class H5PYDataset(Dataset):
    """An h5py-fueled HDF5 dataset.

    This dataset class assumes a particular file layout:

    * Data sources reside in the root group, and their names define the
      source names.
    * Data sources are not explicitly split. Instead, splits are defined
      in the `split` attribute of the root group. It's expected to be a
      1D numpy array of compound ``dtype`` with seven fields, organized as
      follows:

      1. ``split`` : string identifier for the split name
      2. ``source`` : string identifier for the source name
      3. ``start`` : start index (inclusive) of the split in the source
         array, used if ``indices`` is a null reference.
      4. ``stop`` : stop index (exclusive) of the split in the source
         array, used if ``indices`` is a null reference.
      5. ``indices`` : h5py.Reference, reference to a dataset containing
         subset indices for this split/source pair. If it's a null
         reference, ``start`` and ``stop`` are used.
      6. ``available`` : boolean, ``False`` if this split is not available
         for this source.
      7. ``comment`` : comment string

    Parameters
    ----------
    file_or_path : :class:`h5py.File` or str
        HDF5 file handle, or path to the HDF5 file.
    which_sets : iterable of str
        Which split(s) to use. If more than one split is requested,
        the provided sources will be the intersection of provided
        sources for these splits. **Note: for all splits that are
        specified as a list of indices, those indices will get sorted
        no matter what.**
    subset : {slice, list of int}, optional
        Which subset of data to use *within the context of the split*.
        Can be either a slice or a list of indices. Defaults to `None`,
        in which case the whole split is used.
    load_in_memory : bool, optional
        Whether to load the data in main memory. Defaults to `False`.
    driver : str, optional
        Low-level driver to use. Defaults to `None`. See h5py
        documentation for a complete list of available options.
    sort_indices : bool, optional
        HDF5 doesn't support fancy indexing with an unsorted list of
        indices. In order to allow that, the dataset can sort the list
        of indices, access the data in sorted order and shuffle back
        the data in the unsorted order. Setting this flag to `True`
        (the default) will activate this behaviour. For greater
        performance, set this flag to `False`. Note that in that case,
        it is the user's responsibility to make sure that indices are
        ordered.

    Attributes
    ----------
    sources : tuple of strings
        The sources this dataset will provide when queried for data.
    provides_sources : tuple of strings
        The sources this dataset *is able to* provide for the requested
        split.
    example_iteration_scheme : :class:`.IterationScheme` or ``None``
        The iteration scheme the class uses in order to produce a stream of
        examples.
    vlen_sources : tuple of strings
        All sources provided by this dataset which have variable length.
    default_axis_labels : dict mapping string to tuple of strings
        Maps all sources provided by this dataset to their axis labels.

    """
    interface_version = '0.3'
    _ref_counts = defaultdict(int)
    _file_handles = {}

    def __init__(self, file_or_path, which_sets, subset=None,
                 load_in_memory=False, driver=None, sort_indices=True,
                 **kwargs):
        if isinstance(file_or_path, h5py.File):
            self.path = file_or_path.filename
            self.external_file_handle = file_or_path
        else:
            self.path = file_or_path
            self.external_file_handle = None
        which_sets_invalid_value = (
            isinstance(which_sets, six.string_types) or
            not all(isinstance(s, six.string_types) for s in which_sets))
        if which_sets_invalid_value:
            raise ValueError('`which_sets` should be an iterable of strings')
        self.which_sets = which_sets
        self.user_given_subset = subset if subset else slice(None)
        self.load_in_memory = load_in_memory
        self.driver = driver
        self.sort_indices = sort_indices

        self._parse_dataset_info()

        kwargs.setdefault('axis_labels', self.default_axis_labels)
        super(H5PYDataset, self).__init__(**kwargs)

    def _parse_dataset_info(self):
        """Parses information related to the HDF5 interface.

        In addition to verifying that the `self.which_sets` split is
        available, this method sets the following attributes:

        * `provides_sources`
        * `vlen_sources`
        * `default_axis_labels`

        """
        self._out_of_memory_open()
        handle = self._file_handle
        available_splits = self.get_all_splits(handle)
        which_sets = self.which_sets
        provides_sources = None
        for split in which_sets:
            if split not in available_splits:
                raise ValueError(
                    "'{}' split is not provided by this ".format(split) +
                    "dataset. Available splits are " +
                    "{}.".format(available_splits))
            split_provides_sources = set(
                self.get_provided_sources(handle, split))
            if provides_sources:
                provides_sources &= split_provides_sources
            else:
                provides_sources = split_provides_sources
        self.provides_sources = tuple(sorted(provides_sources))
        self.vlen_sources = self.get_vlen_sources(handle)
        self.default_axis_labels = self.get_axis_labels(handle)
        self._out_of_memory_close()

    @staticmethod
    def create_split_array(split_dict):
        """Create a valid array for the `split` attribute of the root node.

        Parameters
        ----------
        split_dict : dict
            Maps split names to dicts. Those dicts map source names to
            tuples. Those tuples contain two, three or four elements:
            the start index, the stop index, (optionally) subset
            indices and (optionally) a comment. If a particular
            split/source combination isn't present in the split dict,
            it's considered unavailable and the ``available`` element
            will be set to ``False`` in its split array entry.

        """
        # Determine maximum split, source and string lengths
        split_len = max(len(split) for split in split_dict)
        sources = set()
        comment_len = 1
        for split in split_dict.values():
            sources |= set(split.keys())
            for val in split.values():
                if len(val) == 4:
                    comment_len = max([comment_len, len(val[-1])])
        sources = sorted(list(sources))
        source_len = max(len(source) for source in sources)

        # Instantiate empty split array
        split_array = numpy.empty(
            len(split_dict) * len(sources),
            dtype=numpy.dtype([
                ('split', 'a', split_len),
                ('source', 'a', source_len),
                ('start', numpy.int64, 1),
                ('stop', numpy.int64, 1),
                ('indices', h5py.special_dtype(ref=h5py.Reference)),
                ('available', numpy.bool, 1),
                ('comment', 'a', comment_len)]))

        # Fill split array
        for i, (split, source) in enumerate(product(split_dict, sources)):
            if source in split_dict[split]:
                start, stop = split_dict[split][source][:2]
                available = True
                indices = h5py.Reference()
                # Workaround for bug when pickling an empty string
                comment = '.'
                if len(split_dict[split][source]) > 2:
                    indices = split_dict[split][source][2]
                if len(split_dict[split][source]) > 3:
                    comment = split_dict[split][source][3]
                    if not comment:
                        comment = '.'
            else:
                (start, stop, indices, available, comment) = (
                    0, 0, h5py.Reference(), False, '.')
            # Workaround for H5PY being unable to store unicode type
            split_array[i]['split'] = split.encode('utf8')
            split_array[i]['source'] = source.encode('utf8')
            split_array[i]['start'] = start
            split_array[i]['stop'] = stop
            split_array[i]['indices'] = indices
            split_array[i]['available'] = available
            split_array[i]['comment'] = comment.encode('utf8')

        return split_array

    @staticmethod
    def get_all_splits(h5file):
        """Returns the names of all splits of an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        available_splits : tuple of str
            Names of all splits in ``h5file``.

        """
        available_splits = tuple(
            set(row['split'].decode('utf8') for row in h5file.attrs['split']))
        return available_splits

    @staticmethod
    def get_all_sources(h5file):
        """Returns the names of all sources of an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        all_sources : tuple of str
            Names of all sources in ``h5file``.

        """
        all_sources = tuple(
            set(row['source'].decode('utf8') for row in h5file.attrs['split']))
        return all_sources

    @staticmethod
    def get_provided_sources(h5file, split):
        """Returns the sources provided by a specific split.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.
        split : str
            Name of the split.

        Returns
        -------
        provided_sources : tuple of str
            Names of sources provided by ``split`` in ``h5file``.

        """
        provided_sources = tuple(
            row['source'].decode('utf8') for row in h5file.attrs['split']
            if row['split'].decode('utf8') == split and row['available'])
        return provided_sources

    @staticmethod
    def get_vlen_sources(h5file):
        """Returns the names of variable-length sources in an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        vlen_sources : tuple of str
            Names of all variable-length sources in ``h5file``.

        """
        vlen_sources = []
        for source_name in H5PYDataset.get_all_sources(h5file):
            source = h5file[source_name]
            if len(source.dims) > 0 and 'shapes' in source.dims[0]:
                if len(source.dims) > 1:
                    raise ValueError('Variable-length sources must have only '
                                     'one dimension.')
                vlen_sources.append(source_name)
        return vlen_sources

    @staticmethod
    def get_axis_labels(h5file):
        """Returns axis labels for all sources in an HDF5 dataset.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.

        Returns
        -------
        axis_labels : dict
            Maps source names to a tuple of str representing the axis
            labels.

        """
        axis_labels = {}
        vlen_sources = H5PYDataset.get_vlen_sources(h5file)
        for source_name in H5PYDataset.get_all_sources(h5file):
            if source_name in vlen_sources:
                axis_labels[source_name] = (
                    (h5file[source_name].dims[0].label,) +
                    tuple(label.decode('utf8') for label in
                          h5file[source_name].dims[0]['shape_labels']))
            else:
                axis_labels[source_name] = tuple(
                    dim.label for dim in h5file[source_name].dims)
        return axis_labels

    @staticmethod
    def get_subsets(h5file, splits, sources):
        """Returns the subsets for a given splits/sources combination.

        Parameters
        ----------
        h5file : HDF5 file handle
            An HDF5 dataset respecting the H5PYDataset interface.
        splits : :class:`tuple` of :class:`str`
            Split names.
        sources : :class:`tuple` of :class:`str`
            Which sources should be considered.

        Returns
        -------
        :class:`list` of :class:`fuel.utils.Subset`
            The subsets, one per source in ``sources``, associated with
            the splits/sources combination.

        """
        subsets = [Subset.empty_subset(len(h5file[source_name]))
                   for source_name in sources]
        for split in splits:
            for i, source in enumerate(sources):
                row, = [r for r in h5file.attrs['split'] if
                        (r['split'].decode('utf8') == split and
                         r['source'].decode('utf8') == source)]
                if row['indices']:
                    subsets[i] += Subset(
                        h5file[row['indices']], len(h5file[source]))
                else:
                    subsets[i] += Subset(
                        slice(row['start'], row['stop']), len(h5file[source]))

        return subsets

    def load(self):
        # If the dataset is unpickled, it makes no sense to have an external
        # file handle. However, since `load` is also called during the lifetime
        # of a dataset (e.g. if load_in_memory = True), we don't want to
        # accidentally overwrite the reference to a potential external file
        # handle, hence this check.
        if not hasattr(self, '_external_file_handle'):
            self.external_file_handle = None

        self._out_of_memory_open()
        handle = self._file_handle

        # Infer subsets based on `which_sets`
        subsets = self.get_subsets(handle, self.which_sets, self.sources)
        # Sanity check to make sure that all sources have equal length
        if any(subset.num_examples != subsets[0].num_examples for subset in
                subsets):
            raise ValueError("sources have different lengths")
        # Produce the final subsets by taking the `subset` constructor argument
        # into account.
        self.subsets = [Subset.subset_of(subset, self.user_given_subset)
                        for subset in subsets]

        # Load data sources and source shapes (if requested)
        if self.load_in_memory:
            data_sources = []
            source_shapes = []
            for source_name, subset in zip(self.sources, self.subsets):
                data_sources.append(
                    subset.index_within_subset(
                        handle[source_name], slice(None)))
                if source_name in self.vlen_sources:
                    shapes = subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'],
                        slice(None))
                else:
                    shapes = None
                source_shapes.append(shapes)
            self.data_sources = tuple(data_sources)
            self.source_shapes = tuple(source_shapes)
            # This exists only for request sanity checking purposes.
            self.in_memory_subset = Subset(
                slice(None), len(self.data_sources[0]))
        else:
            self.data_sources = None
            self.source_shapes = None
            self.in_memory_subset = None

        self._out_of_memory_close()

    @property
    def num_examples(self):
        return self.subsets[0].num_examples

    def open(self):
        return None if self.load_in_memory else self._out_of_memory_open()

    def _out_of_memory_open(self):
        if not self.external_file_handle:
            if self.path not in self._file_handles:
                handle = h5py.File(
                    name=self.path, mode="r", driver=self.driver)
                self._file_handles[self.path] = handle
            self._ref_counts[self.path] += 1

    def close(self, state):
        if not self.load_in_memory:
            self._out_of_memory_close()

    def _out_of_memory_close(self):
        if not self.external_file_handle:
            self._ref_counts[self.path] -= 1
            if not self._ref_counts[self.path]:
                del self._ref_counts[self.path]
                self._file_handles[self.path].close()
                del self._file_handles[self.path]

    @property
    def _file_handle(self):
        if self.external_file_handle:
            return self.external_file_handle
        elif self.path in self._file_handles:
            return self._file_handles[self.path]
        else:
            raise IOError('no open handle for file {}'.format(self.path))

    def get_data(self, state=None, request=None):
        if self.load_in_memory:
            data, shapes = self._in_memory_get_data(state, request)
        else:
            data, shapes = self._out_of_memory_get_data(state, request)
        for i in range(len(data)):
            if shapes[i] is not None:
                for j in range(len(data[i])):
                    data[i][j] = data[i][j].reshape(shapes[i][j])
        return tuple(data)

    def _in_memory_get_data(self, state=None, request=None):
        if state is not None or request is None:
            raise ValueError
        data = [self.in_memory_subset.index_within_subset(data_source, request)
                for data_source in self.data_sources]
        shapes = [self.in_memory_subset.index_within_subset(shape, request)
                  if shape is not None else None
                  for shape in self.source_shapes]
        return data, shapes

    def _out_of_memory_get_data(self, state=None, request=None):
        if not isinstance(request, (slice, list)):
            raise ValueError()
        data = []
        shapes = []
        handle = self._file_handle
        for source_name, subset in zip(self.sources, self.subsets):
            # Process the data request within the context of the data source
            # subset
            data.append(
                subset.index_within_subset(
                    handle[source_name], request,
                    sort_indices=self.sort_indices))
            # If this source has variable length, get the shapes as well
            if source_name in self.vlen_sources:
                shapes.append(
                    subset.index_within_subset(
                        handle[source_name].dims[0]['shapes'], request,
                        sort_indices=self.sort_indices))
            else:
                shapes.append(None)
        return data, shapes
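To make the file layout described in the class docstring concrete, here is a hedged end-to-end sketch: it writes a toy HDF5 file whose root group carries a split attribute built with create_split_array, then loads one split back. The file name and source name are illustrative, and the import assumes the class lives at fuel.datasets.hdf5 as in Fuel:

    import h5py
    import numpy
    from fuel.datasets.hdf5 import H5PYDataset

    with h5py.File('toy.hdf5', mode='w') as f:
        features = f.create_dataset('features', (10, 2), dtype='float32')
        features[...] = numpy.arange(20).reshape(10, 2)
        features.dims[0].label = 'batch'
        features.dims[1].label = 'feature'
        # Rows 0-7 form 'train', rows 8-9 form 'test'.
        split_dict = {'train': {'features': (0, 8)},
                      'test': {'features': (8, 10)}}
        f.attrs['split'] = H5PYDataset.create_split_array(split_dict)

    # Load the training split; the start/stop rows of the split array
    # drive the Subset that get_data indexes through.
    train_set = H5PYDataset('toy.hdf5', which_sets=('train',))
    state = train_set.open()
    data, = train_set.get_data(state, request=slice(0, 4))
    train_set.close(state)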
Example #31
 def test_lists_are_unique_and_sorted(self):
     assert_equal(Subset([0, 3, 3, 5], 10).list_or_slice, [0, 3, 5])
     assert_equal(Subset([0, 3, 1, 5], 10).list_or_slice, [0, 1, 3, 5])
Example #32
class IndexableDataset(Dataset):
    """Creates a dataset from a set of indexable containers.

    Parameters
    ----------
    indexables : :class:`~collections.OrderedDict` or indexable
        The indexable(s) to provide interface to. This means it must
        support the syntax ``indexable[0]``. If an
        :class:`~collections.OrderedDict` is given, its values should be
        indexables providing data, and its keys strings that are used as
        source names. If a single indexable is given, it will be given the
        source ``data``.

    Attributes
    ----------
    indexables : list
        A list of indexable objects.

    Notes
    -----
    If the indexable data is very large, you might want to consider using
    the :func:`.do_not_pickle_attributes` decorator to make sure the data
    doesn't get pickled with the dataset, but gets reloaded/recreated
    instead.

    This dataset also uses the source names to create properties that
    provide easy access to the data.

    """
    def __init__(self, indexables, start=None, stop=None, **kwargs):
        if isinstance(indexables, dict):
            self.provides_sources = tuple(indexables.keys())
        else:
            self.provides_sources = ('data', )
        super(IndexableDataset, self).__init__(**kwargs)
        if isinstance(indexables, dict):
            self.indexables = [
                indexables[source][start:stop] for source in self.sources
            ]
            if not all(
                    len(indexable) == len(self.indexables[0])
                    for indexable in self.indexables):
                raise ValueError("sources have different lengths")
        else:
            self.indexables = [indexables]

        self.example_iteration_scheme = SequentialExampleScheme(
            self.num_examples)

        self.start = start
        self.stop = stop
        self.subset = Subset(slice(start, stop), self.num_examples)

    def __getattr__(self, attr):
        if (attr not in ['sources', 'indexables', '_sources']
                and attr in self.sources):
            return self.indexables[self.sources.index(attr)]
        raise AttributeError

    # Without explicitly defining a trivial __setstate__ method,
    # the __getattribute__ method would call the __getattr__ method,
    # which would raise an AttributeError. This causes problems
    # when unpickling.
    def __setstate__(self, dict):
        self.__dict__ = dict

    @property
    def num_examples(self):
        return len(self.indexables[0])

    def get_data(self, state=None, request=None):
        if state is not None or request is None:
            raise ValueError
        return tuple(
            self.subset.index_within_subset(indexable, request)
            for indexable in self.indexables)
Example #33
 def test_contiguous_lists_are_transformed_into_slices(self):
     assert_equal(Subset([1, 2, 3], 10).list_or_slice, slice(1, 4, None))