Example No. 1
def concatenate(data_list,
                out_group=None,
                start=None,
                stop=None,
                datasets=None,
                dataset_filter=None):
    """Concatenate data along the time axis.

    All :class:`TOData` objects to be concatenated are assumed to have the
    same datasets and index_maps with compatible shapes and data types.

    Currently only 'time' axis concatenation is supported, and it must be the
    fastest varying index.

    All attributes, history, and other non-time-dependent information are
    copied from the first item.

    Parameters
    ----------
    data_list : list of :class:`TOData`
        These are assumed to be identical in every way except along the axes
        representing time, over which they are concatenated. All other data
        and attributes are simply copied from the first entry of the list.
    out_group : `h5py.Group`, hdf5 filename or `memh5.Group`
        Underlying hdf5-like container that will store the data for the
        BaseData instance.
    start : int or dict with keys ``data_list[0].time_axes``
        Index in the aggregate datasets at which to start.  Everything before
        this index is excluded.
    stop : int or dict with keys ``data_list[0].time_axes``
        Index in the aggregate datasets at which to stop.  Everything after
        this index is excluded.
    datasets : sequence of strings
        Which datasets to include.  Default is all of them.
    dataset_filter : callable with one or two arguments
        Function for preprocessing all datasets.  Useful for changing data
        types, etc.  Takes a dataset as an argument and should return a
        dataset (either h5py or memh5).  Optionally it may accept a second
        argument, the slice along the time axis, which the filter should then
        apply itself.

    Returns
    -------
    data : :class:`TOData`

    """

    if dataset_filter is None:

        def dataset_filter(d):
            return d

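    # Detect whether the filter accepts a second (time-slice) argument.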
    filter_time_slice = len(inspect.getfullargspec(dataset_filter).args) == 2

    # Inspect first entry in the list to get constant parts.
    first_data = data_list[0]
    concatenation_axes = first_data.time_axes

    # Ensure *start* and *stop* are mappings.
    if not hasattr(start, "__getitem__"):
        start = {axis: start for axis in concatenation_axes}
    if not hasattr(stop, "__getitem__"):
        stop = {axis: stop for axis in concatenation_axes}

    # Get the length of all axes for which we are concatenating.
    concat_index_lengths = {axis: 0 for axis in concatenation_axes}
    for data in data_list:
        for index_name in concatenation_axes:
            if index_name not in data.index_map:
                continue
            concat_index_lengths[index_name] += len(data.index_map[index_name])

    # Get real start and stop indexes.
    for axis in concatenation_axes:
        start[axis], stop[axis] = _start_stop_inds(start.get(axis, None),
                                                   stop.get(axis, None),
                                                   concat_index_lengths[axis])

    if first_data.distributed and not isinstance(out_group, h5py.Group):
        distributed = True
        comm = first_data.comm
    else:
        distributed = False
        comm = None

    # Choose return class and initialize the object.
    out = first_data.__class__(out_group, distributed=distributed, comm=comm)

    # Resolve the index maps. XXX Shouldn't be necessary after fix to
    # _copy_non_time_data.
    for axis, index_map in first_data.index_map.items():
        if axis in concatenation_axes:
            # Initialize the dataset.
            dtype = index_map.dtype
            out.create_index_map(
                axis, np.empty(shape=(stop[axis] - start[axis], ),
                               dtype=dtype))
        else:
            # Just copy it.
            out.create_index_map(axis, index_map)

    # Copy over the reverse maps.
    for axis, reverse_map in first_data.reverse_map.items():
        out.create_reverse_map(axis, reverse_map)

    all_dataset_names = _copy_non_time_data(data_list, out)
    if datasets is None:
        dataset_names = all_dataset_names
    else:
        dataset_names = datasets

    current_concat_index_start = {axis: 0 for axis in concatenation_axes}
    # Now loop over the list and copy the data.
    for data in data_list:
        # Get the concatenation axis lengths for this BaseData.
        current_concat_index_n = {
            axis: len(data.index_map.get(axis, []))
            for axis in concatenation_axes
        }
        # Start with the index_map.
        for axis in concatenation_axes:
            axis_finished = current_concat_index_start[axis] >= stop[axis]
            axis_not_started = (current_concat_index_start[axis] +
                                current_concat_index_n[axis] <= start[axis])
            if axis_finished or axis_not_started:
                continue
            in_slice, out_slice = _get_in_out_slice(
                start[axis],
                stop[axis],
                current_concat_index_start[axis],
                current_concat_index_n[axis],
            )
            out.index_map[axis][out_slice] = data.index_map[axis][in_slice]
        # Now copy over the datasets and flags.
        this_dataset_names = _copy_non_time_data(data)
        for name in this_dataset_names:
            dataset = data[name]
            if name not in dataset_names:
                continue
            attrs = dataset.attrs

            # Figure out which axis we are concatenating over.
            for a in memh5.bytes_to_unicode(attrs["axis"]):
                if a in concatenation_axes:
                    axis = a
                    break
            else:
                msg = "Dataset %s does not have a valid concatenation axis."
                raise ValueError(msg % name)
            # Figure out where we are in that axis and how to slice it.
            axis_finished = current_concat_index_start[axis] >= stop[axis]
            axis_not_started = (current_concat_index_start[axis] +
                                current_concat_index_n[axis] <= start[axis])
            if axis_finished or axis_not_started:
                continue
            axis_rate = 1  # Placeholder for eventual implementation.
            in_slice, out_slice = _get_in_out_slice(
                start[axis] * axis_rate,
                stop[axis] * axis_rate,
                current_concat_index_start[axis] * axis_rate,
                current_concat_index_n[axis] * axis_rate,
            )

            # Filter the dataset.
            if filter_time_slice:
                dataset = dataset_filter(dataset, in_slice)
            else:
                dataset = dataset_filter(dataset)
            if hasattr(dataset, "attrs"):
                # Some filters modify the attributes; others return an object
                # without attributes, so we need to check.
                attrs = dataset.attrs

            # Do this *after* the filter, in case filter changed axis order.
            axis_ind = list(memh5.bytes_to_unicode(attrs["axis"])).index(axis)

            # Slice input data if the filter doesn't do it.
            if not filter_time_slice:
                in_slice = (slice(None), ) * axis_ind + (in_slice, )
                dataset = dataset[in_slice]

            # The time slice filter above will convert dataset from a MemDataset
            # instance to either an MPIArray or np.ndarray (depending on whether
            # it is distributed).  We need to convert back to the appropriate
            # subclass of MemDataset before initializing the output dataset.
            if not isinstance(dataset, memh5.MemDataset):
                if distributed and isinstance(dataset, mpiarray.MPIArray):
                    dataset = memh5.MemDatasetDistributed.from_mpi_array(
                        dataset)
                else:
                    dataset = memh5.MemDatasetCommon.from_numpy_array(dataset)

            # If this is the first piece of data, initialize the output
            # dataset.
            if name not in out:
                shape = dataset.shape
                dtype = dataset.dtype
                full_shape = shape[:axis_ind]
                full_shape += ((stop[axis] - start[axis]) * axis_rate, )
                full_shape += shape[axis_ind + 1:]
                if distributed and isinstance(dataset,
                                              memh5.MemDatasetDistributed):
                    new_dset = out.create_dataset(
                        name,
                        shape=full_shape,
                        dtype=dtype,
                        distributed=True,
                        distributed_axis=dataset.distributed_axis,
                    )
                else:
                    new_dset = out.create_dataset(name,
                                                  shape=full_shape,
                                                  dtype=dtype)
                memh5.copyattrs(attrs, new_dset.attrs)

            out_dset = out[name]
            out_slice = (slice(None), ) * axis_ind + (out_slice, )

            # Copy the data in.
            out_dtype = out_dset.dtype
            if (out_dtype.kind == "V" and not out_dtype.fields
                    and out_dtype.shape
                    and isinstance(out_dset, h5py.Dataset)):
                # Awkward special case for pure subarray dtypes, which h5py and
                # numpy treat differently.
                # Drop down to low level interface. I think this is only
                # necessary for pretty old h5py.
                from h5py import h5t
                from h5py._hl import selections

                mtype = h5t.py_create(out_dtype)
                mdata = dataset.copy().flat[:]
                mspace = selections.SimpleSelection(
                    (mdata.size // out_dtype.itemsize, )).id
                fspace = selections.select(out_dset.shape, out_slice,
                                           out_dset.id).id
                out_dset.id.write(mspace, fspace, mdata, mtype)
            else:
                out_dset[out_slice] = dataset[:]
        # Increment the start indexes for the next item of the list.
        for axis in current_concat_index_start.keys():
            current_concat_index_start[axis] += current_concat_index_n[axis]

    return out
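
A minimal usage sketch for this variant (not part of the source): it shows the
two-argument form of dataset_filter, which receives the time slice and is
expected to apply it itself. The objects d1 and d2, the 'vis' dataset name,
and the assumption of non-distributed in-memory data are all hypothetical.

import numpy as np
from caput.tod import concatenate  # the function shown above

def cast_and_slice(dataset, time_slice):
    # Two-argument filter: apply the time slice ourselves (time is the last,
    # fastest-varying axis) and cast to single-precision complex.
    return np.asarray(dataset[..., time_slice]).astype(np.complex64)

combined = concatenate(
    [d1, d2],                     # already-loaded TOData objects (hypothetical)
    start={"time": 10},           # skip the first 10 aggregate time samples
    stop={"time": 500},           # assumes >= 500 samples in the combined axis
    datasets=["vis"],             # only concatenate the 'vis' dataset
    dataset_filter=cast_and_slice,
)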
Example No. 2
File: tod.py Project: yodeng/caput
def concatenate(data_list,
                out_group=None,
                start=None,
                stop=None,
                datasets=None,
                dataset_filter=None):
    """Concatenate data along the time axis.

    All :class:`TOData` objects to be concatenated are assumed to have the
    same datasets and index_maps with compatible shapes and data types.

    Currently only 'time' axis concatenation is supported, and it must be the
    fastest varying index.

    All attributes, history, and other non-time-dependent information are
    copied from the first item.

    Parameters
    ----------
    data_list : list of :class:`TOData`
        These are assumed to be identical in every way except along the axes
        representing time, over which they are concatenated. All other data
        and attributes are simply copied from the first entry of the list.
    out_group : `h5py.Group`, hdf5 filename or `memh5.Group`
        Underlying hdf5-like container that will store the data for the
        BaseData instance.
    start : int or dict with keys ``data_list[0].time_axes``
        Index in the aggregate datasets at which to start.  Everything before
        this index is excluded.
    stop : int or dict with keys ``data_list[0].time_axes``
        Index in the aggregate datasets at which to stop.  Everything after
        this index is excluded.
    datasets : sequence of strings
        Which datasets to include.  Default is all of them.
    dataset_filter : callable
        Function for preprocessing all datasets.  Useful for changing data
        types etc.  Should return a dataset.


    Returns
    -------
    data : :class:`TOData`

    """

    if dataset_filter is None:

        def dataset_filter(d):
            return d

    # Inspect first entry in the list to get constant parts.
    first_data = data_list[0]
    concatenation_axes = first_data.time_axes

    # Ensure *start* and *stop* are mappings.
    if not hasattr(start, '__getitem__'):
        start = {axis: start for axis in concatenation_axes}
    if not hasattr(stop, '__getitem__'):
        stop = {axis: stop for axis in concatenation_axes}

    # Get the length of all axes for which we are concatenating.
    concat_index_lengths = {axis: 0 for axis in concatenation_axes}
    for data in data_list:
        for index_name in concatenation_axes:
            if index_name not in data.index_map:
                continue
            concat_index_lengths[index_name] += len(data.index_map[index_name])

    # Get real start and stop indexes.
    for axis in concatenation_axes:
        start[axis], stop[axis] = _start_stop_inds(
            start.get(axis, None),
            stop.get(axis, None),
            concat_index_lengths[axis],
        )

    if first_data.distributed and not isinstance(out_group, h5py.Group):
        distributed = True
        comm = first_data.comm
    else:
        distributed = False
        comm = None

    # Choose return class and initialize the object.
    out = first_data.__class__(out_group, distributed=distributed, comm=comm)

    # Resolve the index maps. XXX Shouldn't be necessary after fix to
    # _copy_non_time_data.
    for axis, index_map in first_data.index_map.items():
        if axis in concatenation_axes:
            # Initialize the dataset.
            dtype = index_map.dtype
            out.create_index_map(
                axis,
                np.empty(shape=(stop[axis] - start[axis], ), dtype=dtype),
            )
        else:
            # Just copy it.
            out.create_index_map(axis, index_map)

    all_dataset_names = _copy_non_time_data(data_list, out)
    if datasets is None:
        dataset_names = all_dataset_names
    else:
        dataset_names = datasets

    current_concat_index_start = {axis: 0 for axis in concatenation_axes}
    # Now loop over the list and copy the data.
    for data in data_list:
        # Get the concatenation axis lengths for this BaseData.
        current_concat_index_n = {
            axis: len(data.index_map.get(axis, []))
            for axis in concatenation_axes
        }
        # Start with the index_map.
        for axis in concatenation_axes:
            axis_finished = current_concat_index_start[axis] >= stop[axis]
            axis_not_started = (current_concat_index_start[axis] +
                                current_concat_index_n[axis] <= start[axis])
            if axis_finished or axis_not_started:
                continue
            in_slice, out_slice = _get_in_out_slice(
                start[axis],
                stop[axis],
                current_concat_index_start[axis],
                current_concat_index_n[axis],
            )
            out.index_map[axis][out_slice] = data.index_map[axis][in_slice]
        # Now copy over the datasets and flags.
        this_dataset_names = _copy_non_time_data(data)
        for name in this_dataset_names:
            dataset = data[name]
            if name not in dataset_names:
                continue
            attrs = dataset.attrs
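            # Apply the (single-argument) preprocessing filter to this dataset.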
            dataset = dataset_filter(dataset)
            if hasattr(dataset, "attrs"):
                # Some filters modify the attributes; others return an object
                # without attributes, so we need to check.
                attrs = dataset.attrs

            # For now, only support concatenation over the minor axis.
            for ii, a in enumerate(attrs['axis']):
                if a in concatenation_axes:
                    axis = a
                    axis_ind = ii
                    break
            else:
                msg = "Dataset %s does not have a valid concatenation axis."
                raise ValueError(msg % name)

            axis_finished = current_concat_index_start[axis] >= stop[axis]
            axis_not_started = (current_concat_index_start[axis] +
                                current_concat_index_n[axis] <= start[axis])
            if axis_finished or axis_not_started:
                continue
            # Placeholder for eventual implementation.
            axis_rate = 1
            # If this is the first piece of data, initialize the output
            # dataset.
            # out_keys = ['flags/' + n for n in  out.flags.keys()]
            # out_keys += out.datasets.keys()
            if name not in out:
                shape = dataset.shape
                dtype = dataset.dtype
                full_shape = shape[:axis_ind]
                full_shape += ((stop[axis] - start[axis]) * axis_rate, )
                full_shape += shape[axis_ind + 1:]
                if (distributed
                        and isinstance(dataset, memh5.MemDatasetDistributed)):
                    new_dset = out.create_dataset(
                        name,
                        shape=full_shape,
                        dtype=dtype,
                        distributed=True,
                        distributed_axis=dataset.distributed_axis,
                    )
                else:
                    new_dset = out.create_dataset(name,
                                                  shape=full_shape,
                                                  dtype=dtype)
                memh5.copyattrs(attrs, new_dset.attrs)
            out_dset = out[name]
            in_slice, out_slice = _get_in_out_slice(
                start[axis] * axis_rate,
                stop[axis] * axis_rate,
                current_concat_index_start[axis] * axis_rate,
                current_concat_index_n[axis] * axis_rate,
            )
            in_slice = (slice(None), ) * axis_ind + (in_slice, )
            out_slice = (slice(None), ) * axis_ind + (out_slice, )
            # Awkward special case for pure subarray dtypes, which h5py and
            # numpy treat differently.
            out_dtype = out_dset.dtype
            if (out_dtype.kind == 'V' and not out_dtype.fields
                    and out_dtype.shape
                    and isinstance(out_dset, h5py.Dataset)):
                # index_pairs = zip(range(dataset.shape[-1])[in_slice],
                #                   range(out_dset.shape[-1])[out_slice])
                # Drop down to low level interface. I think this is only
                # necessary for pretty old h5py.
                from h5py import h5t
                from h5py._hl import selections
                mtype = h5t.py_create(out_dtype)
                mdata = dataset[in_slice].copy().flat[:]
                mspace = selections.SimpleSelection(
                    (mdata.size // out_dtype.itemsize, )).id
                fspace = selections.select(out_dset.shape, out_slice,
                                           out_dset.id).id
                out_dset.id.write(mspace, fspace, mdata, mtype)
            else:
                out_dset[out_slice] = dataset[in_slice]
        # Increment the start indexes for the next item of the list.
        for axis in current_concat_index_start.keys():
            current_concat_index_start[axis] += current_concat_index_n[axis]

    return out
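
For this older variant, where dataset_filter takes a single argument and
concatenate() performs all of the time slicing itself, a usage sketch might
look like the following (again, d1 and d2 are hypothetical, already-loaded
TOData objects; none of this appears in the source).

import numpy as np
from caput.tod import concatenate  # the function shown above

def to_single_precision(dataset):
    # One-argument filter: only change the dtype; concatenate() applies the
    # time slicing itself in this variant.
    return dataset[:].astype(np.float32)

# With start/stop left at their defaults, the full aggregate time range is kept.
combined = concatenate([d1, d2], dataset_filter=to_single_precision)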