Example #1
class H5Writer:
    L = TypeVar("L", List[ndarray], Dict[str, ndarray])

    #__H5PY: h5py.File = h5py.File(fileobj=None, mode=None)
    def __init__(self, filename: str) -> None:
        self.__file = File(filename, 'a')

    def saveImgDataIntoGroup(self, imgData: L, groupName: str,
                             datasetNames: List[str]) -> None:
        #with File(filename, 'a') as file:
        group: Group = self.__file.create_group(groupName)
        print('... group was created successfully!')
        assert (len(imgData) == len(datasetNames)
                ), 'the number of data to save and dataset names are not equal'
        for i in range(len(datasetNames)):
            group.create_dataset(datasetNames[i],
                                 data=asarray(imgData[i]),
                                 compression='gzip',
                                 compression_opts=9)
            print('... dataset was created successfully!')

    def loadImgDataFromGroup(self,
                             groupName: Optional[str] = None,
                             datasetNames: Optional[str] = None) -> Generator:
        # with File(filename, "r") as file:
        keys: List[str] = list(self.__file.keys())
        imgArray: ndarray = None
        if (keys):
            print(keys)
        else:
            pass
        if (groupName):
            group: Group = self.__file.get(
                groupName)  # group2 = hf.get('group2/subfolder')
            items: List[Tuple] = list(
                group.items()
            )  # [(u'data3', <HDF5 dataset "data3": shape (100, 3333), type "<f8">)]
            if (items):
                print(items)
                try:
                    for i in range(len(items)):
                        print('recovering group class:',
                              group.get(items[i][0]))
                        yield asarray(
                            group.get(items[i][0])
                        )  # n1 = group1.get('data1') \n np.array(n1).shape
                except StopIteration:
                    self.closingH5PY()
            else:
                pass
        elif (datasetNames):
            #cls.__closingH5PY(file)
            yield self.__file.get(
                datasetNames)  # n1 = group1.get('data1') \n np.array(n1).shape

            #print("Size of List", len(imgArray), "Size of Tuple", len(imgArray[0]), "Size of Array", imgArray[0][0].shape, imgArray[0][0].size)

    def closingH5PY(self) -> None:
        self.__file.close()
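
A minimal usage sketch for H5Writer follows (the file name and frame data are invented; it assumes the snippet's implied imports such as h5py.File and numpy.asarray are in place): write two frames into a group, then stream them back through the generator.

# Hypothetical usage of H5Writer; 'images.h5' and the frame contents are placeholders.
import numpy as np

writer = H5Writer('images.h5')
frames = [np.zeros((64, 64)), np.ones((64, 64))]
writer.saveImgDataIntoGroup(frames, groupName='session1',
                            datasetNames=['frame0', 'frame1'])

# loadImgDataFromGroup is a generator, so iterate to recover the arrays.
for img in writer.loadImgDataFromGroup(groupName='session1'):
    print(img.shape)
writer.closingH5PY()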
Example #2
    def test_copy_file_with_external_links(self):
        # Setup all the data we need
        start_time = datetime(2017, 4, 3, 11, 0, 0)
        create_date = datetime(2017, 4, 15, 12, 0, 0)
        data = np.arange(1000).reshape((100, 10))
        timestamps = np.arange(100)
        # Create the first file
        nwbfile1 = NWBFile(source='PyNWB tutorial',
                           session_description='demonstrate external files',
                           identifier='NWBE1',
                           session_start_time=start_time,
                           file_create_date=create_date)

        test_ts1 = TimeSeries(name='test_timeseries',
                              source='PyNWB tutorial',
                              data=data,
                              unit='SIunit',
                              timestamps=timestamps)
        nwbfile1.add_acquisition(test_ts1)
        # Write the first file
        self.io[0].write(nwbfile1)
        nwbfile1_read = self.io[0].read()

        # Create the second file
        nwbfile2 = NWBFile(source='PyNWB tutorial',
                           session_description='demonstrate external files',
                           identifier='NWBE1',
                           session_start_time=start_time,
                           file_create_date=create_date)

        test_ts2 = TimeSeries(name='test_timeseries',
                              source='PyNWB tutorial',
                              data=nwbfile1_read.get_acquisition('test_timeseries').data,
                              unit='SIunit',
                              timestamps=timestamps)
        nwbfile2.add_acquisition(test_ts2)
        # Write the second file
        self.io[1].write(nwbfile2)
        self.io[1].close()
        self.io[0].close()  # Don't forget to close the first file too

        # Copy the file
        self.io[2].close()
        HDF5IO.copy_file(source_filename=self.test_temp_files[1].name,
                         dest_filename=self.test_temp_files[2].name,
                         expand_external=True,
                         expand_soft=False,
                         expand_refs=False)

        # Test that everything is working as expected
        # Confirm that our original data file is correct
        f1 = File(self.test_temp_files[0].name)
        self.assertTrue(isinstance(f1.get('/acquisition/test_timeseries/data', getlink=True), HardLink))
        # Confirm that we successfully created an External Link in our second file
        f2 = File(self.test_temp_files[1].name)
        self.assertTrue(isinstance(f2.get('/acquisition/test_timeseries/data', getlink=True), ExternalLink))
        # Confirm that we successfully resolved the External Link when we copied our second file
        f3 = File(self.test_temp_files[2].name)
        self.assertTrue(isinstance(f3.get('/acquisition/test_timeseries/data', getlink=True), HardLink))
Example #3
def get_h5(file_name, check=False, index=None):
    h5_object = File(ROOT_DIR + file_name, 'r')
    data = h5_object.get('data')
    if check:
        time = h5_object.get('dates')
        time_obj = time[()].view('<M8[D]')
    else:
        time = None
        time_obj = None
    if index is not None:
        data = data[index]
    return data, time, time_obj
def get_internal_hdf():
    try:
        hdf_in = File(config.data.HDFS_INTERNAL_DATA_FILENAME, "r")
    except IOError:
        print("The internal file failed to open for read.")
    else:
        x_trnset = hdf_in.get("/dataset/train/x_trainset")
        y_trnset = hdf_in.get("/dataset/train/y_trainset")
        x_valset = hdf_in.get("/dataset/val/x_valset")
        y_valset = hdf_in.get("/dataset/val/y_valset")
        x_testset = hdf_in.get("/dataset/test/x_testset")
        y_testset = hdf_in.get("/dataset/test/y_testset")
        return x_trnset, y_trnset, x_valset, y_valset, x_testset, y_testset, hdf_in
def get_external_hdf():
    try:
        hdf_ext = File(
            config.external_data_sources.HDFS_EXTERNAL_DATA_FILENAME, "r")
        return hdf_ext.get("external_dataset"), hdf_ext
    except IOError:
        print("The external file failed to open for read.")
Example #6
File: test_io_hdf5.py  Project: t-b/hdmf
 def check_fields(self):
     f = File(self.path)
     self.assertIn('test_bucket', f)
     bucket = f.get('test_bucket')
     self.assertIn('foo_holder', bucket)
     holder = bucket.get('foo_holder')
     self.assertIn('foo1', holder)
     return f
Example #7
 def importAsLocalDataset(
     self, project_file: h5py.File, progress_signal: Callable[[int], None] = lambda x: None
 ) -> str:
     project = Project(project_file)
     inner_path = project.local_data_group.name + "/" + self.legacy_datasetId
     if project_file.get(inner_path) is not None:
         return inner_path
     self.dumpToHdf5(h5_file=project_file, inner_path=inner_path, progress_signal=progress_signal)
     return inner_path
Example #8
 def _get_hash_dict_from_paths(hdf: h5py.File,
                               paths: List[str]) -> Dict[int, str]:
     hash_dict = {}
     for path in paths:
         g = hdf.get(path)
         hash = HDU.get_attr(g, 'hash', None)
         if hash is None:
             logger.warning(f'Fit at {path} had no hash')
         else:
             hash_dict[hash] = path
     return hash_dict
Example #9
    def test_copy_file_with_external_links(self):

        # Setup all the data we need
        foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14)
        bucket1 = FooBucket('test_bucket1', [foo1])

        foofile1 = FooFile('test_foofile1', buckets=[bucket1])

        # Write the first file
        self.io[0].write(foofile1)
        bucket1_read = self.io[0].read()

        # Create the second file

        foo2 = Foo('foo2', bucket1_read.buckets[0].foos[0].my_data, "I am foo2", 34, 6.28)

        bucket2 = FooBucket('test_bucket2', [foo2])
        foofile2 = FooFile('test_foofile2', buckets=[bucket2])
        # Write the second file
        self.io[1].write(foofile2)
        self.io[1].close()
        self.io[0].close()  # Don't forget to close the first file too

        # Copy the file
        self.io[2].close()
        HDF5IO.copy_file(source_filename=self.test_temp_files[1],
                         dest_filename=self.test_temp_files[2],
                         expand_external=True,
                         expand_soft=False,
                         expand_refs=False)

        # Test that everything is working as expected
        # Confirm that our original data file is correct
        f1 = File(self.test_temp_files[0])
        self.assertIsInstance(f1.get('/buckets/test_bucket1/foo_holder/foo1/my_data', getlink=True), HardLink)
        # Confirm that we successfully created an External Link in our second file
        f2 = File(self.test_temp_files[1])
        self.assertIsInstance(f2.get('/buckets/test_bucket2/foo_holder/foo2/my_data', getlink=True), ExternalLink)
        # Confirm that we successfully resolved the External Link when we copied our second file
        f3 = File(self.test_temp_files[2])
        self.assertIsInstance(f3.get('/buckets/test_bucket2/foo_holder/foo2/my_data', getlink=True), HardLink)
Example #10
def get_h5_data(reflection_param_h5, fields):
    """
    Collects the data in the HDF5 dataset saved by 'acenvgenmodel_collect_results.py'. Returns
    them in terms of the array containing the data for each field, the shape of the data and the
    names corresponding to each acoustic environment being modeled.

    Args:
        reflection_param_h5: The location of the saved HDF5 dataset
        fields: The field names to extract the values for from the HDF5 dataset

    Returns:
        The values of the data in the fields as numpy arrays
        The shape of the descriptor for each environment per field
        The name of each environment

    """
    import numpy as np
    from h5py import File
    print('Reading :' + reflection_param_h5)
    hf = File(reflection_param_h5, 'r')
    field_values = []
    for i, the_field in enumerate(fields):
        field_values.append(np.array(hf.get(the_field), dtype=float))
        if the_field == 'y':
            if not model_hits:
                field_values[-1] = field_values[-1][:, :, [0, 1]]
            else:
                field_values[-1] = field_values[-1][:, :, 0:-1]
    names = np.array(hf.get('names')).flatten()
    hf.close()
    if not np.all(
            np.equal(field_values[0].shape[0], [
                field_values[ii].shape[0]
                for ii in range(1, len(field_values))
            ])):
        raise AssertionError('Data lengths don\'t match')
    field_shapes = [
        field_values[ii].shape[1:] for ii in range(len(field_values))
    ]
    print('Got ' + str(field_values[0].shape[0]) + ' entries')
    return field_values, field_shapes, names
Example #11
def get_phase(args):
    
    filename = args[0]
    path = args[1]
    path_raw = args[2]
    path_images = args[3]
    mask = args[4]
    coord = args[5]

    file_in = os.path.join(path,filename)
    file_raw = os.path.join(path_raw,'raw_'+filename)
    image_phase = os.path.join(path_images,'wrapped'+filename[4:11]+'bmp')
    binary_phase = os.path.join(path_raw,'wrapped'+filename[4:11]+'dat')
    mod_arr = os.path.join(path_raw,'mod'+filename[4:11]+'dat')
    mod_image = os.path.join(path_images,'mod'+filename[4:11]+'bmp')
    qual_arr = os.path.join(path_raw,'qual'+filename[4:11]+'dat')
    qual_image = os.path.join(path_images,'qual'+filename[4:11]+'bmp')
    # Open meas file and grab dataset
    try:
        f = File(file_in, 'r')
    except:
        print('Corrupt h5 file: ' + filename + ' ignoring')
        return
    sub = f.get(r'measurement0/frames/frame_full/data')
    data = np.array(sub[coord[0]-1:coord[1]+1,coord[2]-1:coord[3]+1],'f')
    f.close()

    # Get phase
    phase, modulation, intensity = calc_phase(data)
    # Apply mask
    phase[~mask] = 0
    intensity[~mask] = 0
    modulation[~mask] = 0
    #phase = phase[coord[0]:coord[1],coord[2]:coord[3]]
    
    # Save phase
    toimage(phase).save(image_phase)
    phase.tofile(binary_phase)


    ave_mod = np.average(modulation[mask])
    ave_int = np.average(intensity[mask])
    '''
    if ave_mod < 0.6:
        print filename+' low mod:', ave_mod
    else:
        sys.stdout.write('.')
    '''
    return "%s,%f,%f\n" % (filename, ave_int, ave_mod)
Example #12
 def test_nwbio(self):
     io = HDF5IO(self.path, self.manager)
     io.write(self.container)
     io.close()
     f = File(self.path)
     self.assertIn('acquisition', f)
     self.assertIn('analysis', f)
     self.assertIn('general', f)
     self.assertIn('processing', f)
     self.assertIn('file_create_date', f)
     self.assertIn('identifier', f)
     self.assertIn('session_description', f)
     self.assertIn('session_start_time', f)
     acq = f.get('acquisition')
     self.assertIn('test_timeseries', acq)
Example #13
def load_coach_ranks(ranking: h5py.File, coach_nick: str):
    coach = load_coach(ranking, coach_nick)

    for mus, phis, period in zip(coach['mu'][:], coach['phi'][:],
                                 ranking['date'][:]):
        if all(np.isnan(mus)):
            break
        for mu, phi, race in zip(mus, phis, ranking.get('race_ids')):
            if not np.isnan(mu):
                yield (
                    coach_nick,
                    nafstat.races.RACES.by_race[race.decode('utf-8')].race_id,
                    mu,
                    phi,
                    period,
                )
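
load_coach_ranks is a generator, so it is typically consumed in a loop; a brief usage sketch (the file name and coach nickname are placeholders):

# Sketch: iterate over the (nick, race_id, mu, phi, period) tuples for one coach.
import h5py

with h5py.File('ranking.h5', 'r') as ranking:
    for nick, race_id, mu, phi, period in load_coach_ranks(ranking, 'SomeCoach'):
        print(nick, race_id, mu, phi, period)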
Example #14
def find_links(
    input_file: h5py.File,
    link_names: Optional[List[str]] = None,
    link_paths: Optional[List[str]] = None,
    path: Optional[str] = None,
) -> Tuple[List[str], List[str]]:
    """
    Recursively finds all the links in the snapshot and writes them to a list

    Parameters
    ----------
    input_file : h5py.File
        hdf5 file handle for snapshot
    link_names : list of str, optional
        names of links found in the snapshot
    link_paths : list of str, optional
        paths where links found in the snapshot point to
    path : str, optional
        the path to the current location in the snapshot

    Returns
    -------
    link_names, link_paths : list of str, list of str
        lists of the names and links of paths in `input_file`
    """
    if path is not None:
        keys = input_file[path].keys()
    else:
        keys = input_file.keys()
        path = ""

    if link_names is None:
        link_names = []
    if link_paths is None:
        link_paths = []
    for key in keys:
        subpath = f"{path}/{key}"
        dataset = input_file.get(subpath, getlink=True)
        if isinstance(dataset, h5py.SoftLink):
            link_names.append(subpath.lstrip("/"))
            link_paths.append(dataset.path)
        else:
            try:
                if input_file[subpath].keys() is not None:
                    find_links(input_file, link_names, link_paths, subpath)
            except:
                pass

    return link_names, link_paths
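
A self-contained sketch of the behaviour (file name and group layout are invented): a soft link is reported by its name together with the path it points to.

# Sketch: build a tiny file containing one soft link, then locate it.
import h5py
import numpy as np

with h5py.File('snapshot_links.h5', 'w') as snap:
    snap.create_dataset('PartType0/Coordinates', data=np.zeros((4, 3)))
    snap['PartType1'] = h5py.SoftLink('/PartType0')

with h5py.File('snapshot_links.h5', 'r') as snap:
    names, paths = find_links(snap)
    print(names)  # ['PartType1']
    print(paths)  # ['/PartType0']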
Example #15
File: database.py  Project: holm10/UEDGE
def create_dict(dump,variables=None):
    ''' Creates a dictionary of all variables in 'dump'
        Parameters:
            dump        Path to dump to be dictionarized
            variables   List of variables to be saved (default is all) - no packages
    '''
    from numpy import array
    ret=dict()          # Dictionary to be returned
    f=File(dump,'r')    # Open readable dump as f
    for pack in f.keys():   # Loop through the list of packages
        p=f.get(pack)       # Get the package object p
        for var in p.keys():    # Loop through all variables in the package
            if variables is None:   # By default, save all variables to dict
                ret[pack+'.'+var]=array(p.get(var))    # Store the variable to the dictionary as array
            elif var in variables:  # If variables listed, save only listed variables
                ret[pack+'.'+var]=array(p.get(var))    # Store the variable to the dictionary as array
    return ret
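
A small sketch of the expected dump layout (names are placeholders): each top-level group is treated as a package and every dataset inside it lands in the dictionary under 'package.variable'.

# Sketch: write a dump with one package holding two variables, then read it back.
from h5py import File
import numpy as np

with File('uedge_dump.h5', 'w') as f:
    pkg = f.create_group('bbb')
    pkg.create_dataset('te', data=np.linspace(1.0, 2.0, 5))
    pkg.create_dataset('ni', data=np.ones((3, 3)))

d = create_dict('uedge_dump.h5')
print(sorted(d.keys()))   # ['bbb.ni', 'bbb.te']
print(d['bbb.te'].shape)  # (5,)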
Example #16
    def get_array_from_hdf(self, hdf: h5py.File):
        """
        NOT EASY TO CACHE READ (See self.calculate_from_raw_data)
        Same as self.calculate_array but handles getting data from HDF (Note: reads are uncached!)
        Opens the dataset, applies multiply/offset/bad_rows etc and returns the result
        Args:
            hdf (): HDF file the data exists in

        Returns:
            (np.ndarray): Array of data after necessary modification
        """
        dataset = hdf.get(self.data_path)
        assert isinstance(dataset, h5py.Dataset)
        good_slice = self._good_slice(dataset.shape)
        data = dataset[good_slice]
        data = self._calculate_offset_multiple(data)
        return data
Example #17
 def check_fields(self):
     f = File(self.path)
     self.assertIn('acquisition', f)
     self.assertIn('analysis', f)
     self.assertIn('epochs', f)
     self.assertIn('general', f)
     self.assertIn('processing', f)
     self.assertIn('file_create_date', f)
     self.assertIn('identifier', f)
     self.assertIn('session_description', f)
     self.assertIn('nwb_version', f)
     self.assertIn('session_start_time', f)
     acq = f.get('acquisition')
     self.assertIn('images', acq)
     self.assertIn('timeseries', acq)
     ts = acq.get('timeseries')
     self.assertIn('test_timeseries', ts)
Example #18
def loadsignal(fname, channel, txtfile=False):
    if txtfile:
        file = fname + ".txt"
        txtdata = np.loadtxt(file)
        data = txtdata[:, channel + 1].reshape(-1, 1)
        srate = 1000
        rawtime = (np.arange(len(data)) / srate).reshape(-1, 1)
        print("txtload")
    else:
        channel = "channel_" + str(channel)
        fname = fname + ".h5"
        h5_object = File(fname)
        h5_group = h5_object.get('00:07:80:58:9B:3F')
        srate = h5_group.attrs.get("sampling rate")
        h5_sub_group = h5_group.get("raw")
        data_chan = h5_sub_group.get(channel)
        data = [item for sublist in data_chan for item in sublist]
        rawtime = np.array(bsnb.generate_time(data, srate))
        rawtime = rawtime.reshape(-1, 1)
    return (np.array(data, dtype=np.float64)).reshape(-1, 1), rawtime, srate
Example #19
def read_dense_h5(file, name):
    """
    Load a dense data array stored via `dense_to_h5`.

    Parameters
    ----------
    file : str
        Name of the h5 file from where the data should be loaded.
    name : str
        Name of the variable stored in the h5 file.

    Returns
    -------
    np.ndarray
        Loaded dense array.

    """
    hf = File(file, 'r')
    array = np.array(hf.get(name))
    hf.close()
    return array
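
The matching writer dense_to_h5 is referenced by the docstring but not shown here; a plausible sketch, kept symmetric with read_dense_h5 (opening in 'a' mode is a choice so several variables can share one file):

# Sketch of a possible dense_to_h5 counterpart: store one array under `name`.
import numpy as np
from h5py import File

def dense_to_h5(file, name, array):
    """Store a dense numpy array so read_dense_h5(file, name) can recover it."""
    with File(file, 'a') as hf:
        hf.create_dataset(name, data=np.asarray(array))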
Example #20
def _update_pdb_dsets(file: h5py.File, name: str,
                      logger: Optional[Logger] = None) -> Optional[PDBContainer]:
    """Check for and update pre dataCAT 0.3 style databases."""
    if not isinstance(file.get(name), h5py.Dataset):
        return None
    elif logger is not None:
        logger.info(f'Updating h5py Dataset to data-CAT >= 0.3 style: {name!r}')

    mol_list = [from_pdb_array(pdb, rdmol=False, warn=False) for pdb in file[name]]
    m = len(mol_list)
    del file[name]

    dtype = IDX_DTYPE[name]
    scale = np.rec.array(None, dtype=dtype, shape=(m,))
    if dtype.fields is not None and scale.size:
        # Ensure that the sentinel value for vlen strings is an empty string, not `None`
        elem = list(scale.item(0))
        iterator = (v for v, *_ in dtype.fields.values())
        for i, sub_dt in enumerate(iterator):
            if h5py.check_string_dtype(sub_dt) is not None:
                elem[i] = ''
        scale[:] = tuple(elem)
    return PDBContainer.from_molecules(mol_list, scale=scale)
Example #21
def read_sparse_h5(file, name, shape):
    """
    Load a sparse data array stored via `sparse_to_h5`.

    Parameters
    ----------
    file : str
        Name of the h5 file from where the data should be loaded.
    name : str
        Name of the variable stored in the h5 file.
    shape : tuple of int
        Shape of the sparse array to reconstruct.

    Returns
    -------
    sparse.COO
        Loaded sparse array.

    """
    hf = File(file, 'r')
    raw = np.array(hf.get(name))
    hf.close()
    coords = raw[:-1]
    data = raw[-1]
    return sparse.COO(coords, data, shape=shape)
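
read_sparse_h5 expects the coordinates stacked on top of the values as one 2-D array (raw[:-1] holds the COO coordinates, raw[-1] the data). A plausible sparse_to_h5 counterpart that writes exactly that layout:

# Sketch of a possible sparse_to_h5 counterpart matching the layout read above.
import numpy as np
from h5py import File

def sparse_to_h5(file, name, coo_array):
    """Store a sparse.COO array; its shape must be passed to read_sparse_h5 separately."""
    raw = np.vstack([coo_array.coords, coo_array.data])  # (ndim + 1, nnz)
    with File(file, 'a') as hf:
        hf.create_dataset(name, data=raw)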
Example #22
File: mask.py  Project: rockman507/tfi_py
def get_mask(path, border=0, size=True):

    #Get mask shape
    f = File(path, 'r')
    sub = f.get(r'measurement0/maskshapes/Detector')
    sub2 = f.get(r'measurement0/frames/frame_full/data')
    X,Y = sub2.shape
    X -= 2
    Y -= 2
    temp = sub.attrs.get('shape1')
    temp = re.split('[| ,]', temp)
    f.close()

    left = int(float(temp[1]))
    top = int(float(temp[3]))
    width = int(float(temp[5]))
    arr_size = temp[-1]

    # If array is quarter size, expand to same size array as phase
    if float(arr_size) < 750:
        left *=2
        top *=2
        width *=2

    # Create circular mask
    r = (width / 2)
    a = top+r
    b = left+r
    xx,yy=np.ogrid[-a:X-a,-b:Y-b]
    mask = xx*xx + yy*yy < r*r
    '''
    # Set boundaries of mask
    tempx,tempy = np.where(mask)
    y_min = tempy.min()-border
    y_max = tempy.max()+border+1
    x_min = tempx.min()-border
    x_max = tempx.max()+border+1
    '''
    y_min = left+1-border
    y_max = left-1+width+border
    x_min = top+1-border
    x_max = top-1+width+border
    
    #print left, top, width
    #print y_min, y_max, x_min, x_max

    x1 = x_max - x_min
    y1 = y_max - y_min    
    size1 = np.array([x1,y1], dtype='int')
    m1 = mask[x_min:x_max,y_min:y_max]

    
    size_path = os.path.join(os.path.dirname(path), r'size.dat')
    size1.tofile(size_path)
    mask_path = os.path.join(os.path.dirname(path), r'mask.dat')
    m1.tofile(mask_path)


    if size:
        return mask, (x_min, x_max, y_min, y_max)
    else:
        return mask
Example #23
def create_new_associations(train_associations,
                            utts_per_channel,
                            bs_h5,
                            verbose=True):
    """
    Creates associations between AIRs and speech files, which do not mix the speech data used
    across different rooms. This is useful in the case where new AIRs are added to the training
    (as possibly a result of data augmentation) and you do not want to add more speech samples
    but to strictly reuse the same ones you used in a previous experiment, in order not to
    introduce any extra variability.

    Args:
        train_associations: A pandas DataFrame with the index being the ACE AIR filename (it can
        be full or just the filename), one column called 'speech', which contains the path to
        the wav speech utterance used and one column called 'offsets', which contains the sample
        offset used for the utterance convolution
        utts_per_channel: Number of speech utterances to associate to each new AIR
        bs_h5: The HDF5 dataset location which contains the new AIR information. The dataset has
        2 fields, one is 'filenames', which contains the locations of the wav AIRs and the other is
        'chan', which indicates the number of channels in the audio file
        verbose: Verbose reporting

    Returns:

    """
    from h5py import File
    import numpy as np
    from random import randint
    try:
        from os.path import basename
    except ImportError:
        raise

    hf = File(bs_h5, 'r')
    wav_files = (np.array(hf.get('filenames')).astype(str)).tolist()

    basenames = [
        thename.split('/')[-1].replace('EE_lobby', 'EE-lobby')
        for thename in wav_files
    ]
    room = [thename.split('_')[1] for thename in basenames]

    basenames_train = [
        thename.split('/')[-1].replace('EE_lobby', 'EE-lobby')
        for thename in train_associations.index
    ]
    room_train = [thename.split('_')[1] for thename in basenames_train]

    new_association = {
        'airs': ([None] * len(wav_files)),
        'speech': [],
        'offsets': []
    }

    if verbose:
        print('Creating speech associations:')
    for i in range(len(wav_files)):
        for _ in range(utts_per_channel):
            same_idxs = np.where(room[i] == np.array(room_train))[-1].tolist()
            new_idx = same_idxs[randint(0, len(same_idxs) - 1)]
            new_association['speech'].append(
                train_associations['speech'][new_idx])
            new_association['offsets'].append(
                train_associations['offsets'][new_idx])
            print(basenames[i] + ' -> ' +
                  train_associations.index[new_idx].split('/')[-1] + ',' +
                  train_associations['speech'][new_idx] + '@' +
                  str(train_associations['offsets'][new_idx]))

    return new_association
Example #24
def ref_elasticity(tetra: bool = True, r_lvl: int = 0, out_hdf5: h5py.File = None,
                   xdmf: bool = False, boomeramg: bool = False, kspview: bool = False, degree: int = 1):
    if tetra:
        N = 3 if degree == 1 else 2
        mesh = create_unit_cube(MPI.COMM_WORLD, N, N, N)
    else:
        N = 3
        mesh = create_unit_cube(MPI.COMM_WORLD, N, N, N, CellType.hexahedron)
    for i in range(r_lvl):
        # set_log_level(LogLevel.INFO)
        N *= 2
        if tetra:
            mesh = refine(mesh, redistribute=True)
        else:
            mesh = create_unit_cube(MPI.COMM_WORLD, N, N, N, CellType.hexahedron)
        # set_log_level(LogLevel.ERROR)
    N = degree * N
    fdim = mesh.topology.dim - 1
    V = VectorFunctionSpace(mesh, ("Lagrange", int(degree)))

    # Generate Dirichlet BC on lower boundary (Fixed)
    u_bc = Function(V)
    with u_bc.vector.localForm() as u_local:
        u_local.set(0.0)

    def boundaries(x):
        return np.isclose(x[0], np.finfo(float).eps)
    facets = locate_entities_boundary(mesh, fdim, boundaries)
    topological_dofs = locate_dofs_topological(V, fdim, facets)
    bc = dirichletbc(u_bc, topological_dofs)
    bcs = [bc]

    # Create traction meshtag
    def traction_boundary(x):
        return np.isclose(x[0], 1)
    t_facets = locate_entities_boundary(mesh, fdim, traction_boundary)
    facet_values = np.ones(len(t_facets), dtype=np.int32)
    arg_sort = np.argsort(t_facets)
    mt = meshtags(mesh, fdim, t_facets[arg_sort], facet_values[arg_sort])

    # Elasticity parameters
    E = PETSc.ScalarType(1.0e4)
    nu = 0.1
    mu = Constant(mesh, E / (2.0 * (1.0 + nu)))
    lmbda = Constant(mesh, E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu)))
    g = Constant(mesh, PETSc.ScalarType((0, 0, -1e2)))
    x = SpatialCoordinate(mesh)
    f = Constant(mesh, PETSc.ScalarType(1e4)) * \
        as_vector((0, -(x[2] - 0.5)**2, (x[1] - 0.5)**2))

    # Stress computation
    def sigma(v):
        return (2.0 * mu * sym(grad(v)) + lmbda * tr(sym(grad(v))) * Identity(len(v)))

    # Define variational problem
    u = TrialFunction(V)
    v = TestFunction(V)
    a = inner(sigma(u), grad(v)) * dx
    rhs = inner(g, v) * ds(domain=mesh, subdomain_data=mt, subdomain_id=1) + inner(f, v) * dx

    num_dofs = V.dofmap.index_map.size_global * V.dofmap.index_map_bs
    if MPI.COMM_WORLD.rank == 0:
        print("Problem size {0:d} ".format(num_dofs))

    # Generate reference matrices and unconstrained solution
    bilinear_form = form(a)
    A_org = assemble_matrix(bilinear_form, bcs)
    A_org.assemble()
    null_space_org = rigid_motions_nullspace(V)
    A_org.setNearNullSpace(null_space_org)

    linear_form = form(rhs)
    L_org = assemble_vector(linear_form)
    apply_lifting(L_org, [bilinear_form], [bcs])
    L_org.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES, mode=PETSc.ScatterMode.REVERSE)
    set_bc(L_org, bcs)
    opts = PETSc.Options()
    if boomeramg:
        opts["ksp_type"] = "cg"
        opts["ksp_rtol"] = 1.0e-5
        opts["pc_type"] = "hypre"
        opts['pc_hypre_type'] = 'boomeramg'
        opts["pc_hypre_boomeramg_max_iter"] = 1
        opts["pc_hypre_boomeramg_cycle_type"] = "v"
        # opts["pc_hypre_boomeramg_print_statistics"] = 1

    else:
        opts["ksp_rtol"] = 1.0e-8
        opts["pc_type"] = "gamg"
        opts["pc_gamg_type"] = "agg"
        opts["pc_gamg_coarse_eq_limit"] = 1000
        opts["pc_gamg_sym_graph"] = True
        opts["mg_levels_ksp_type"] = "chebyshev"
        opts["mg_levels_pc_type"] = "jacobi"
        opts["mg_levels_esteig_ksp_type"] = "cg"
        opts["matptap_via"] = "scalable"
        opts["pc_gamg_square_graph"] = 2
        opts["pc_gamg_threshold"] = 0.02
    # opts["help"] = None # List all available options
    # opts["ksp_view"] = None # List progress of solver

    # Create solver, set operator and options
    solver = PETSc.KSP().create(MPI.COMM_WORLD)
    solver.setFromOptions()
    solver.setOperators(A_org)

    # Solve linear problem
    u_ = Function(V)
    start = perf_counter()
    with Timer("Ref solve"):
        solver.solve(L_org, u_.vector)
    end = perf_counter()
    u_.x.scatter_forward()

    if kspview:
        solver.view()

    it = solver.getIterationNumber()
    if out_hdf5 is not None:
        d_set = out_hdf5.get("its")
        d_set[r_lvl] = it
        d_set = out_hdf5.get("num_dofs")
        d_set[r_lvl] = num_dofs
        d_set = out_hdf5.get("solve_time")
        d_set[r_lvl, MPI.COMM_WORLD.rank] = end - start

    if MPI.COMM_WORLD.rank == 0:
        print("Refinement level {0:d}, Iterations {1:d}".format(r_lvl, it))

    # List memory usage
    mem = sum(MPI.COMM_WORLD.allgather(
        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))
    if MPI.COMM_WORLD.rank == 0:
        print("{1:d}: Max usage after trad. solve {0:d} (kb)"
              .format(mem, r_lvl))

    if xdmf:

        # Name formatting of functions
        u_.name = "u_unconstrained"
        fname = "results/ref_elasticity_{0:d}.xdmf".format(r_lvl)
        with XDMFFile(MPI.COMM_WORLD, fname, "w") as out_xdmf:
            out_xdmf.write_mesh(mesh)
            out_xdmf.write_function(u_, 0.0, "Xdmf/Domain/Grid[@Name='{0:s}'][1]".format(mesh.name))
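
ref_elasticity only fills datasets that already exist in out_hdf5 ('its', 'num_dofs', 'solve_time'), so the caller has to allocate them up front; a sketch of that setup, assuming an MPI-enabled h5py build (the file name and number of refinement levels are arbitrary):

# Sketch: pre-allocate the result datasets written by ref_elasticity above
# (the "mpio" driver assumes h5py was built against parallel HDF5).
import h5py
import numpy as np
from mpi4py import MPI

N_ref = 3  # hypothetical number of refinement levels
out_hdf5 = h5py.File("ref_elasticity_results.h5", "w",
                     driver="mpio", comm=MPI.COMM_WORLD)
out_hdf5.create_dataset("its", (N_ref,), dtype=np.int32)
out_hdf5.create_dataset("num_dofs", (N_ref,), dtype=np.int64)
out_hdf5.create_dataset("solve_time", (N_ref, MPI.COMM_WORLD.size), dtype=np.float64)

for r_lvl in range(N_ref):
    ref_elasticity(tetra=True, r_lvl=r_lvl, out_hdf5=out_hdf5)
out_hdf5.close()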
Example #25
def load_coach(ranking: h5py.File, coach_nick):
    coach = ranking.get("/coaches/{}".format(coach_nick))
    if not coach:
        raise KeyError('load_coach %s not found. (Case sensitive)' % coach_nick)
    return {'coach': coach, 'mu': coach['mu'], 'phi': coach['phi']}
def read_air_and_filters_xy(h5_files,
                            framesize=None,
                            get_pow_spec=True,
                            max_air_len=None,
                            fs=None,
                            forced_fs=None,
                            keep_ids=None,
                            start_at_max=True,
                            max_air_read=None):
    latest_file = '../results_dir/training_test_data.h5'
    from os.path import isfile
    import numpy as np
    from h5py import File

    from resampy import resample

    ids = None
    x = None
    all_boudnaries = None

    if forced_fs is None:
        forced_fs = fs
    resample_op = lambda x: x
    if not forced_fs == fs:
        resample_op = lambda x: np.array(
            resample(np.array(x.T, dtype=np.float64), fs, forced_fs, 0)).T

    if max_air_read is not None:
        if fs is None:
            raise AssertionError('Cannot work with max_air_read without fs')
        max_air_read_samples = int(np.ceil(fs * max_air_read))
    for i, this_h5 in enumerate(h5_files):
        print("Reading : " + this_h5 + " @ " + str(i + 1) + " of " +
              str(len(h5_files)))
        hf = File(this_h5, 'r')

        names = np.array(hf.get('names'))
        airs = np.array(hf.get('airs')).T

        boundaries = np.array(hf.get('boundary_ids')).T

        if i > 0:
            ids = np.concatenate((ids, names))
        else:
            ids = names

        print("Got " + str(airs.shape))
        airs = resample_op(airs)
        if max_air_read is not None:
            airs = airs[:, 0:max_air_read_samples]
        if i > 0:
            if x.shape[1] < airs.shape[1]:
                npads = -x.shape[1] + airs.shape[1]
                x = np.concatenate((x, np.zeros(
                    (x.shape[0], npads)).astype(x.dtype)),
                                   axis=1)
                x = np.concatenate((x, airs), axis=0)
            else:
                if x.shape[1] > airs.shape[1]:
                    npads = x.shape[1] - airs.shape[1]
                    airs = np.concatenate(
                        (airs, np.zeros(
                            (airs.shape[0], npads)).astype(airs.dtype)),
                        axis=1)
                x.resize((x.shape[0] + airs.shape[0], x.shape[1]))
                x[-airs.shape[0]:, :] = airs
        else:
            x = np.array(airs)

        if i > 0:
            all_boudnaries = np.concatenate((all_boudnaries, boundaries),
                                            axis=0)
        else:
            all_boudnaries = boundaries

    class_names = np.unique(all_boudnaries)
    y = np.zeros((all_boudnaries.shape[0], class_names.size)).astype(bool)

    for i, cname in enumerate(class_names):
        y[np.any(all_boudnaries == cname, axis=1), i] = True

    if keep_ids is not None:
        y = y[:,
              np.in1d(class_names.astype(int),
                      np.array(keep_ids).astype(int))]

    if fs is not None:
        print('Got ' + str(x.shape[0]) + ' AIRs of duration ' +
              str(x.shape[1] / float(fs)))
    else:
        print('Got ' + str(x.shape[0]) + ' AIRs of length ' + str(x.shape[1]))

    x = data_post_proc(x, fs, start_at_max, framesize, get_pow_spec,
                       max_air_len)

    print('Left with ' + str(x.shape) + ' AIRs data ')

    ids = ids.astype(str)
    class_names = class_names.astype(str)
    return (x, y), ids, class_names
Example #27
class LabberData:
    """Labber save data in HDF5 files and organize them by channel.

    A channel is either a swept variable or a measured quantities. We use
    either string or integers to identify channels.

    """

    #: Path to the HDF5 file containing the data.
    path: str

    #: Name of the file (ie no directories)
    filename: str = field(init=False)

    #: Private reference to the underlying HDF5 file in which the data are stored
    _file: Optional[File] = field(default=None, init=False)

    #: Groups in which the data of appended measurements are stored.
    _nested: List[Group] = field(default_factory=list, init=False)

    #: Names of the channels accessible in the Data or Traces segments of the file
    _channel_names: Optional[List[str]] = field(default=None, init=False)

    #: Detailed information about the channels.
    _channels: Optional[List[Union[LogEntry,
                                   StepConfig]]] = field(default=None,
                                                         init=False)

    #: Steps performed in the measurement.
    _steps: Optional[List[StepConfig]] = field(default=None, init=False)

    #: Log entries of the measurement.
    _logs: Optional[List[LogEntry]] = field(default=None, init=False)

    def __post_init__(self) -> None:
        self.filename = self.path.rsplit(os.sep, 1)[-1]

    def open(self) -> None:
        """ Open the underlying HDF5 file."""
        self._file = File(self.path, "r")

        # Identify nested dataset
        i = 2
        nested = []
        while f"Log_{i}" in self._file:
            nested.append(self._file[f"Log_{i}"])
            i += 1

        if nested:
            self._nested = nested

    def close(self) -> None:
        """ Close the underlying HDF5 file and clean cached values."""
        if self._file:
            self._file.close()
        self._file = None
        self._channel_names = None
        self._channels = None
        self._axis_dimensions = None
        self._nested = []
        self._steps = None
        self._logs = None

    def list_steps(self) -> List[StepConfig]:
        """List the different steps of a measurement."""
        if not self._file:
            raise RuntimeError("No file currently opened")

        if not self._steps:
            steps = []
            for (
                    step,
                    _,  # unknown
                    _,  # unknown
                    _,  # 0 no sweep (ie direct update), 1 between points, 2 continuous
                    _,  # unknown
                    has_relation,
                    relation,
                    _,  # unknown
                    _,  # unknown
                    alternate,
                    *_,
            ) in self._file["Step list"]:
                log_configs = [
                    f["Step config"][step]["Step items"]
                    for f in [self._file] + self._nested
                ]

                # A step is considered ramped if it has more than one config in any log,
                # if the first ramp of any log is ramped, or if there is more than one
                # value for a given constant across different logs.
                # The format describing a single config is:
                # ramped, unknown, set value, min, max, center, span, step,
                # number of points, kind(ie linear, log, etc), sweep rate
                is_ramped = (any(len(configs) > 1 for configs in log_configs)
                             or any(
                                 bool(configs[0][0])
                                 for configs in log_configs)
                             or len({configs[0][2]
                                     for configs in log_configs}) > 1)

                # We assume that if we have relations in one log we have them in all
                if has_relation:
                    rel_params = self._file["Step config"][step][
                        "Relation parameters"]
                    relation = (
                        relation,
                        {
                            k: v
                            for k, v, _ in rel_params
                            # Preserve only the parameters useful to the relation
                            # \W is a non word character (no letter no digit)
                            if re.match(r"(.*\W+" + f"{k})|{k}" +
                                        r"(\W+.*|$)", relation)
                        },
                    )
                else:
                    relation = None

                steps.append(
                    StepConfig(
                        name=step,
                        is_ramped=is_ramped,
                        relation=relation,
                        value=None if is_ramped else log_configs[0][0][2],
                        alternate_direction=alternate,
                        ramps=[(RampConfig(
                            start=cfg[3],
                            stop=cfg[4],
                            steps=cfg[8],
                            new_log=bool(i == 0),
                        ) if cfg[0] else RampConfig(
                            start=cfg[2],
                            stop=cfg[2],
                            steps=1,
                            new_log=bool(i == 0),
                        )) for configs in log_configs
                               for i, cfg in enumerate(configs)]
                        if is_ramped else None,
                    ))

            # Mark all channels with relation to a ramped channel as ramped.
            # One can inspect ramps to know if a step is ramped outside of a relation.
            for step in steps:
                if step.relation is not None:
                    step.is_ramped |= any(
                        s.is_ramped for s in steps
                        if s.name in step.relation[1].values())

            self._steps = steps
        return self._steps

    def list_logs(self) -> List[LogEntry]:
        """List the existing log entries in the datafile."""
        if not self._file:
            raise RuntimeError("No file currently opened")

        if not self._logs:
            # Collect all logs names
            names = [e[0] for e in self._file["Log list"]]

            # Identify scalar complex data
            complex_scalars = [
                n for n, v in self._file["Data"]["Channel names"]
                if v == "Real"
            ]

            # Identify vector data
            vectors = self._file.get("Traces", ())

            logs = []
            for n in names:
                if n in vectors:
                    logs.append(
                        LogEntry(
                            name=n,
                            is_vector=True,
                            is_complex=self._file["Traces"][n].attrs.get(
                                "complex"),
                            x_name=self._file["Traces"][n].attrs.get(
                                "x, name"),
                        ))
                else:
                    logs.append(
                        LogEntry(name=n,
                                 is_complex=bool(n in complex_scalars)))
            self._logs = logs

        return self._logs

    def list_channels(self):
        """Identify the channel availables in the Labber file.

        Channels data can be retieved using get_data.

        """
        if self._channel_names is None:
            self._channel_names = [
                s.name for s in self.list_steps() if s.is_ramped
            ] + [l.name for l in self.list_logs()]
            self._channels = [s for s in self.list_steps() if s.is_ramped
                              ] + [l for l in self.list_logs()]

        return self._channel_names

    @overload
    def get_data(
        self,
        name_or_index: Union[str, int],
        filters: Optional[dict] = None,
        filter_precision: float = 1e-10,
        get_x: Literal[False] = False,
    ) -> np.ndarray:
        pass

    @overload
    def get_data(
        self,
        name_or_index: Union[str, int],
        filters: Optional[dict] = None,
        filter_precision: float = 1e-10,
        get_x: Literal[True] = True,
    ) -> Tuple[np.ndarray, np.ndarray]:
        pass

    def get_data(self,
                 name_or_index,
                 filters=None,
                 filter_precision=1e-10,
                 get_x=False):
        """Retrieve data base on channel name or index

        Parameters
        ----------
        name_or_index : str | int
            Name or index of the channel whose data should be loaded.

        filters : dict, optional
            Dictionary specifying channel (as str or int), value pairs to use
            to filter the returned array. For example, passing {0: 1.0} will
            ensure that only the data for which the value of channel 0 is 1.0
            are returned.

        filter_precision : float, optional
            Precision used when comparing the data to the mask value.

        get_x : bool
            Specify for vector data whether the x-data should be returned along with y-data

        Returns
        -------
        data : Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
            Requested data (or x and requested data for vector data when x is required).
            The data are formatted such that the last axis corresponds to the innermost
            loop of the measurement, and the shape matches the non-filtered steps of the
            measurement.

        """
        if not self._file:
            msg = ("The underlying file needs to be opened before accessing "
                   "data. Either call open or better use a context manager.")
            raise RuntimeError(msg)

        # Convert the name into a channel index.
        index = self._name_or_index_to_index(name_or_index)

        # Get the channel description that contains important details
        # (is_vector, is_complex, )
        if self._channels is None:
            self.list_channels()
        channel = self._channels[index]  # type: ignore

        x_data: List[np.ndarray] = []
        vectorial_data = bool(
            isinstance(channel, LogEntry) and channel.is_vector)
        vec_dim = 0
        if vectorial_data:
            aux = self._get_traces_data(
                channel.name,
                is_complex=channel.is_complex,
                get_x=get_x  # type: ignore
            )
            # Always transpose vector to be able to properly filter them
            if not get_x:
                data = [a.T for a in aux[0]]
            else:
                x_data, data = [a.T for a in aux[0]], [a.T for a in aux[1]]
            vec_dim = data[0].shape[-1]
        else:
            data = self._get_data_data(channel.name,
                                       is_complex=getattr(
                                           channel, "is_complex", False))

        # Filter the data based on the provided filters.
        filters = filters if filters is not None else {}

        # Only use positive indexes to describe filtering
        filters = {
            self._name_or_index_to_index(k): v
            for k, v in filters.items()
        }
        if filters:
            datasets = [self._file] + self._nested[:]
            results = []
            x_results = []
            for i, d in enumerate(datasets):
                masks = []
                for k, v in filters.items():
                    index = self._name_or_index_to_index(k)
                    filter_data = d["Data"]["Data"][:, index]
                    # Create the mask filtering the data
                    mask = np.less(np.abs(filter_data - v), filter_precision)
                    # If the mask is not empty, ensure we do not eliminate nans
                    # added by Labber to have complete sweeps
                    if np.any(mask):
                        mask |= np.isnan(filter_data)
                    masks.append(mask)

                mask = masks.pop()
                for m in masks:
                    mask &= m
                # For unclear reasons, vector data are not indexed in the same order
                # as other data and require the mask to be transposed before being
                # raveled
                if vectorial_data:
                    mask = np.ravel(mask.T)

                # Filter
                results.append(data[i][mask])
                if vectorial_data and get_x:
                    x_results.append(x_data[i][mask])

                # If the filtering produces an empty output return early
                if not any(len(r) for r in results):
                    if get_x:
                        return np.empty(0), np.empty(0)
                    else:
                        return np.empty(0)

        else:
            results = data
            x_results = x_data

        # Identify the ramped steps not used for filtering
        steps_points = []
        first_step_is_used = False
        for i, s in reversed(list(enumerate(self.list_steps()))):

            # Use ramps rather than is_ramped to consider only steps manually
            # ramped, and exclude steps with multiple values because they
            # have a relation to a ramped step but no ramp of their own.
            if s.ramps is not None and i not in filters:

                # Labber stores scalar data as 3D:
                # - the first dimension is the first step number of points
                # - the second one refer to the channel
                # - the third refer to all other steps but in reverse order
                # For vector data it is the same except that the first is not special.
                if i == 0 and not vectorial_data:
                    steps_points.insert(0, s.points_per_log)
                    first_step_is_used = True
                else:
                    steps_points.append(s.points_per_log)

        if vectorial_data:
            steps_points.append((vec_dim, ) * len(steps_points[0]))

        # Get expected shape per log
        shape_per_logs = np.array(steps_points).T
        shaped_results = []
        shaped_x = []
        for i, shape in enumerate(shape_per_logs):
            # If the filtering produced an empty array skip it
            if results[i].shape == (0, ):
                continue

            # The outer most dimension of the scan corresponds either to the first
            # index if the first step was filtered on (not first_step_is_used) or,
            # otherwise, to the second.
            points_inner_dimensions = (np.prod(
                shape[1:]) if not first_step_is_used else shape[0] *
                                       np.prod(shape[2:]))
            padding = np.prod(results[i].shape) % (points_inner_dimensions)
            # Pad the data to ensure that only the last axis shrinks
            if padding:
                # Compute the number of points to add
                to_add = points_inner_dimensions - padding
                results[i] = np.concatenate(
                    (results[i], np.nan * np.ones(to_add)), None)
                if vectorial_data and get_x:
                    x_results[i] = np.concatenate(
                        (x_results[i], np.nan * np.ones(to_add)), None)

            # Allow the most outer shape to shrink
            new_shape = list(shape)
            new_shape[1 if first_step_is_used and len(shape) > 1 else 0] = -1
            shaped_results.append(results[i].reshape(new_shape))
            if vectorial_data and get_x:
                shaped_x.append(x_results[i].reshape(new_shape))

        # Create the complete data
        full_data = np.concatenate(shaped_results, axis=-1)
        if vectorial_data and get_x:
            full_x = np.concatenate(shaped_x, axis=-1)

        # Move the first axis to last position so that the dimensions are in the reverse
        # order of the steps.
        if first_step_is_used:
            full_data = np.moveaxis(full_data, 0, -1)

        if vectorial_data and get_x:
            return full_x, full_data
        else:
            return full_data

    def _name_or_index_to_index(self, name_or_index: Union[str, int]) -> int:
        """Provide the index of a channel from its name."""
        ch_names = self.list_channels()

        if isinstance(name_or_index, str):
            if name_or_index not in ch_names:
                msg = (f"The specified name ({name_or_index}) does not exist "
                       f"in the dataset. Existing names are {ch_names}")
                raise ValueError(msg)
            return ch_names.index(name_or_index)
        elif name_or_index >= len(ch_names):
            msg = (f"The specified index ({name_or_index}) "
                   f"exceeds the number of channel: {len(ch_names)}")
            raise ValueError(msg)
        else:
            return name_or_index

    def _get_traces_data(self,
                         channel_name: str,
                         is_complex: bool = False,
                         get_x: bool = False) -> Tuple[List[np.ndarray], ...]:
        """Get data stored as traces ie vector."""
        if not self._file:
            raise RuntimeError("No file currently opened")

        if channel_name not in self._file["Traces"]:
            raise ValueError(f"Unknown traces data {channel_name}")

        x_data = []
        data = []
        # Traces dimensions are (sweep, real/imag/x, steps)
        for storage in [self._file] + self._nested:
            if is_complex:
                real = storage["Traces"][channel_name][:, 0, :]
                imag = storage["Traces"][channel_name][:, 1, :]
                data.append(real + 1j * imag)
            else:
                data.append(storage["Traces"][channel_name][:, 0, :])
            if get_x:
                x_data.append(storage["Traces"][channel_name][:, -1, :])

        if get_x:
            return x_data, data
        else:
            return (data, )

    def _get_data_data(self,
                       channel_name: str,
                       is_complex: bool = False) -> List[np.ndarray]:
        """Pull data stored in the data segment of the log file."""
        if not self._file:
            raise RuntimeError("No file currently opened")

        names = [n for n, _ in self._file["Data"]["Channel names"]]
        if is_complex:
            re_index = names.index(channel_name)
            im_index = re_index + 1
            real = self._pull_nested_data(re_index)
            imag = self._pull_nested_data(im_index)
            return [r + 1j * i for r, i in zip(real, imag)]
        else:
            return self._pull_nested_data(names.index(channel_name))

    def _pull_nested_data(self, index: int) -> List[np.ndarray]:
        """Pull data stored in the data segmentfrom all nested logs."""
        if not self._file:
            raise RuntimeError("No file currently opened")
        data = [self._file["Data"]["Data"][:, index]]
        for internal in self._nested:
            data.append(internal["Data"]["Data"][:, index])
        return data

    def __enter__(self) -> "LabberData":
        """ Open the underlying HDF5 file when used as a context manager."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """ Close the underlying HDF5 file when used as a context manager.

        """
        self.close()
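
A short usage sketch for LabberData (the file name, channel names and filter value below are placeholders): the class is meant to be used as a context manager, and get_data accepts channel names or indices plus optional filters.

# Sketch: open a Labber log, list its channels and pull one of them, keeping only
# the points where a (hypothetical) gate-voltage channel equals 0.5.
with LabberData("measurement.hdf5") as data:
    print(data.list_channels())
    current = data.get_data("Lock-in - Value", filters={"Gate - Voltage": 0.5})
    print(current.shape)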
Example #28
def reference_periodic(tetra: bool,
                       r_lvl: int = 0,
                       out_hdf5: h5py.File = None,
                       xdmf: bool = False,
                       boomeramg: bool = False,
                       kspview: bool = False,
                       degree: int = 1):
    # Create mesh and finite element
    if tetra:
        # Tet setup
        N = 3
        mesh = create_unit_cube(MPI.COMM_WORLD, N, N, N)
        for i in range(r_lvl):
            mesh.topology.create_entities(mesh.topology.dim - 2)
            mesh = refine(mesh, redistribute=True)
            N *= 2
    else:
        # Hex setup
        N = 3
        for i in range(r_lvl):
            N *= 2
        mesh = create_unit_cube(MPI.COMM_WORLD, N, N, N, CellType.hexahedron)

    V = FunctionSpace(mesh, ("CG", degree))

    # Create Dirichlet boundary condition

    def dirichletboundary(x):
        return np.logical_or(
            np.logical_or(np.isclose(x[1], 0), np.isclose(x[1], 1)),
            np.logical_or(np.isclose(x[2], 0), np.isclose(x[2], 1)))

    mesh.topology.create_connectivity(2, 1)
    geometrical_dofs = locate_dofs_geometrical(V, dirichletboundary)
    bc = dirichletbc(PETSc.ScalarType(0), geometrical_dofs, V)
    bcs = [bc]

    # Define variational problem
    u = TrialFunction(V)
    v = TestFunction(V)
    a = inner(grad(u), grad(v)) * dx
    x = SpatialCoordinate(mesh)
    dx_ = x[0] - 0.9
    dy_ = x[1] - 0.5
    dz_ = x[2] - 0.1
    f = x[0] * sin(5.0 * pi * x[1]) + 1.0 * exp(
        -(dx_ * dx_ + dy_ * dy_ + dz_ * dz_) / 0.02)
    rhs = inner(f, v) * dx

    # Assemble LHS matrix and RHS vector, then apply lifting
    bilinear_form = form(a)
    linear_form = form(rhs)
    A_org = assemble_matrix(bilinear_form, bcs)
    A_org.assemble()
    L_org = assemble_vector(linear_form)
    apply_lifting(L_org, [bilinear_form], [bcs])
    L_org.ghostUpdate(addv=PETSc.InsertMode.ADD_VALUES,
                      mode=PETSc.ScatterMode.REVERSE)
    set_bc(L_org, bcs)

    # Create PETSc nullspace
    nullspace = PETSc.NullSpace().create(constant=True)
    PETSc.Mat.setNearNullSpace(A_org, nullspace)

    # Set PETSc options
    opts = PETSc.Options()
    if boomeramg:
        opts["ksp_type"] = "cg"
        opts["ksp_rtol"] = 1.0e-5
        opts["pc_type"] = "hypre"
        opts['pc_hypre_type'] = 'boomeramg'
        opts["pc_hypre_boomeramg_max_iter"] = 1
        opts["pc_hypre_boomeramg_cycle_type"] = "v"
        # opts["pc_hypre_boomeramg_print_statistics"] = 1
    else:
        opts["ksp_type"] = "cg"
        opts["ksp_rtol"] = 1.0e-12
        opts["pc_type"] = "gamg"
        opts["pc_gamg_type"] = "agg"
        opts["pc_gamg_sym_graph"] = True

        # Use Chebyshev smoothing for multigrid
        opts["mg_levels_ksp_type"] = "richardson"
        opts["mg_levels_pc_type"] = "sor"
    # opts["help"] = None # List all available options
    # opts["ksp_view"] = None # List progress of solver

    # Initialize PETSc solver, set options and operator
    solver = PETSc.KSP().create(MPI.COMM_WORLD)
    solver.setFromOptions()
    solver.setOperators(A_org)

    # Solve linear problem
    u_ = Function(V)
    start = perf_counter()
    with Timer("Solve"):
        solver.solve(L_org, u_.vector)
    end = perf_counter()
    u_.vector.ghostUpdate(addv=PETSc.InsertMode.INSERT,
                          mode=PETSc.ScatterMode.FORWARD)
    if kspview:
        solver.view()

    it = solver.getIterationNumber()
    num_dofs = V.dofmap.index_map.size_global * V.dofmap.index_map_bs
    if out_hdf5 is not None:
        d_set = out_hdf5.get("its")
        d_set[r_lvl] = it
        d_set = out_hdf5.get("num_dofs")
        d_set[r_lvl] = num_dofs
        d_set = out_hdf5.get("solve_time")
        d_set[r_lvl, MPI.COMM_WORLD.rank] = end - start

    if MPI.COMM_WORLD.rank == 0:
        print("Rlvl {0:d}, Iterations {1:d}".format(r_lvl, it))

    # Output solution to XDMF
    if xdmf:
        ext = "tet" if tetra else "hex"
        fname = "results/reference_periodic_{0:d}_{1:s}.xdmf".format(
            r_lvl, ext)
        u_.name = "u_" + ext + "_unconstrained"
        with XDMFFile(MPI.COMM_WORLD, fname, "w") as out_periodic:
            out_periodic.write_mesh(mesh)
            out_periodic.write_function(
                u_, 0.0,
                "Xdmf/Domain/" + "Grid[@Name='{0:s}'][1]".format(mesh.name))
示例#29
0
class LabberData:
    """Labber saves data in HDF5 files and organizes them by channel.

    A channel is either a swept variable or a measured quantity. We use either
    strings or integers to identify channels.
    """

    #: Path to the HDF5 file containing the data.
    path: Union[str, Path]

    #: Name of the file (ie no directories)
    filename: str = field(init=False)

    __file: Optional[File] = field(default=None, init=False)

    @property
    def _file(self):
        """The underlying HDF5 file in which the data are stored."""
        if self.__file is None:
            raise RuntimeError("No HDF5 file is currently opened.")
        return self.__file

    @_file.setter
    def _file(self, value):
        self.__file = value

    #: Groups in which the data of appended measurements are stored.
    _nested: List[Group] = field(default_factory=list, init=False)

    #: Steps performed in the measurement.
    _steps: Optional[List[StepConfig]] = field(default=None, init=False)

    #: Log entries of the measurement.
    _logs: Optional[List[LogEntry]] = field(default=None, init=False)

    def __post_init__(self) -> None:
        if isinstance(self.path, Path):
            self.path = str(self.path)
        self.filename = self.path.rsplit(os.sep, 1)[-1]

    def open(self) -> None:
        """ Open the underlying HDF5 file."""
        self._file = File(self.path, "r")

        # Identify nested dataset
        i = 2
        nested = []
        while f"Log_{i}" in self._file:
            nested.append(self._file[f"Log_{i}"])
            i += 1

        if nested:
            self._nested = nested

    def close(self) -> None:
        """ Close the underlying HDF5 file and clean cached values."""
        if self._file:
            self._file.close()
        self._file = None
        self._nested = []
        self._steps = None
        # clear the @cached_properties if they've been cached
        for attribute in [
            "instrument_configs",
            "channels",
            "channel_names",
            "logs",
            "steps",
        ]:
            try:
                delattr(self, attribute)
            except AttributeError:
                pass

    @cached_property
    def steps(self) -> List[StepConfig]:
        """The step configurations in the Labber file.

        These are listed in the Measurement Editor under "Step sequence" and in the Log
        Browser under "Active channels" with a step list.  They are listed in the
        "/Step list" HDF5 dataset with further information in the "/Step config" group.
        """
        steps = []
        for (
            channel_name,
            _,  # unknown
            _,  # unknown
            _,  # 0 no sweep (ie direct update), 1 between points, 2 continuous
            _,  # unknown
            has_relation,
            relation,
            _,  # unknown
            _,  # unknown
            alternate,
            *_,
        ) in self._file["Step list"]:

            # Decode byte string from the hdf5 file
            channel_name = maybe_decode(channel_name)
            relation = maybe_decode(relation)

            step_configs = [
                f["Step config"][channel_name]["Step items"]
                for f in [self._file] + self._nested
            ]

            # A step is considered ramped if it has more than one config in any log,
            # if the first ramp of any log is ramped, or if there is more than one
            # value for a given constant across different logs.
            # The format describing a single config is:
            # ramped, unknown, set value, min, max, center, span, step,
            # number of points, kind (ie linear, log, etc), sweep rate
            is_ramped = (
                any(len(configs) > 1 for configs in step_configs)
                or any(bool(configs[0][0]) for configs in step_configs)
                or len({configs[0][2] for configs in step_configs}) > 1
            )

            # We assume that if we have relations in one log we have them in all
            if has_relation:
                rel_params = {
                    maybe_decode(k): maybe_decode(v)
                    for k, v, _ in self._file["Step config"][channel_name][
                        "Relation parameters"
                    ]
                }
                relation = (
                    relation,
                    {
                        k: v
                        for k, v in rel_params.items()
                        # Preserve only the parameters useful to the relation
                        # \W is a non-word character (neither letter nor digit)
                        if re.match(r"(.*\W+" + f"{k})|{k}" + r"(\W+.*|$)", relation)
                    },
                )
            else:
                relation = None

            steps.append(
                StepConfig(
                    name=channel_name,
                    is_ramped=is_ramped,
                    relation=relation,
                    value=None if is_ramped else step_configs[0][0][2],
                    alternate_direction=alternate,
                    ramps=[
                        (
                            RampConfig(
                                start=cfg[3],
                                stop=cfg[4],
                                steps=cfg[8],
                                new_log=bool(i == 0),
                            )
                            if cfg[0]
                            else RampConfig(
                                start=cfg[2],
                                stop=cfg[2],
                                steps=1,
                                new_log=bool(i == 0),
                            )
                        )
                        for configs in step_configs
                        for i, cfg in enumerate(configs)
                    ]
                    if is_ramped
                    else None,
                )
            )

        # Mark all channels with relation to a ramped channel as ramped.
        # One can inspect ramps to know if a step is ramped outside of a relation.
        for step in steps:
            if step.relation is not None:
                step.is_ramped |= any(
                    s.is_ramped for s in steps if s.name in step.relation[1].values()
                )

        return steps

    def get_step(self, name: str) -> StepConfig:
        """Get a step by name."""
        for step in self.steps:
            if step.name == name:
                return step
        raise ValueError(
            f"The requested step `{name}` does not exist in `{self.filename}`."
            f"Available steps are: {[step.name for step in self.steps]}"
        )

    @cached_property
    def logs(self) -> List[LogEntry]:
        """The logged channels in the Labber file.

        These are listed in the Measurement Editor under "Log channels" and in the Log
        Browser under "Active channels" with no step list.  They are listed in the
        "/Log list" HDF5 dataset with further information in the "/Data" group (for
        scalar data) and "/Traces" group (for vectorial data).
        """
        log_names = [maybe_decode(e[0]) for e in self._file["Log list"]]

        # identify scalar complex data
        complex_scalars = [
            maybe_decode(n)
            for n, v in self._file.get("Data/Channel names", ())
            if maybe_decode(v) == "Real"
        ]

        # identify vector data
        vectors = self._file.get("Traces", ())

        logs = []
        for name in log_names:
            if name in vectors:
                logs.append(
                    LogEntry(
                        name=name,
                        is_vector=True,
                        is_complex=self._file[f"Traces/{name}"].attrs.get("complex"),
                        x_name=self._file[f"Traces/{name}"].attrs.get("x, name"),
                    )
                )
            else:
                logs.append(LogEntry(name=name, is_complex=name in complex_scalars))

        return logs

    @cached_property
    def channels(self) -> List[Union[LogEntry, StepConfig]]:
        """Channels that are stepped or logged in the Labber file.

        Channels include step channels and log channels as viewed in the Measurement
        Editor under "Step sequence" and "Log channels" and are listed under "Active
        channels" in the Log Browser.

        Channel data can be retrieved using `get_data`.

        Returns
        -------
        Ramped step channels and all log channels.
        """
        return [s for s in self.steps if s.is_ramped] + [l for l in self.logs]

    @cached_property
    def channel_names(self) -> List[str]:
        """Names of the channels available in the Labber file."""
        return [c.name for c in self.channels]

    @cached_property
    def instrument_configs(self) -> List[InstrumentConfig]:
        """Instrument configurations for the measurement."""
        return [
            InstrumentConfig(group)
            for group in self._file["Instrument config"].values()
        ]

    @overload
    def get_data(
        self,
        name_or_index: Union[str, int],
        filters: Optional[dict] = None,
        filter_precision: float = 1e-10,
        get_x: Literal[False] = False,
    ) -> np.ndarray:
        pass

    @overload
    def get_data(
        self,
        name_or_index: Union[str, int],
        filters: Optional[dict] = None,
        filter_precision: float = 1e-10,
        get_x: Literal[True] = True,
    ) -> Tuple[np.ndarray, np.ndarray]:
        pass

    def get_data(
        self, name_or_index, filters=None, filter_precision=1e-10, get_x=False
    ):
        """Retrieve data base on channel name or index

        Parameters
        ----------
        name_or_index : str | int
            Name or index of the channel whose data should be loaded.

        filters : dict, optional
            Dictionary specifying channel (as str or int), value pairs to use
            to filter the returned array. For example, passing {0: 1.0} will
            ensure that only the data for which the value of channel 0 is 1.0
            are returned.

        filter_precision : float, optional
            Precision used when comparing the data to the mask value.

        get_x : bool
            For vector data, specify whether the x-data should be returned along with the y-data.

        Returns
        -------
        data : Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
            Requested data (or x and requested data for vector data when x is required).
            The data are formatted such that the last axis corresponds to the innermost
            loop of the measurement and the shape matches the non-filtered steps of the
            measurement.

        """
        # Convert the name into a channel index.
        index = self._name_or_index_to_index(name_or_index)

        # Get the channel description that contains important details
        # (is_vector, is_complex, etc.)
        channel = self.channels[index]  # type: ignore

        x_data: List[np.ndarray] = []
        vectorial_data = bool(isinstance(channel, LogEntry) and channel.is_vector)
        vec_dim = 0
        if vectorial_data:
            aux = self._get_traces_data(
                channel.name, is_complex=channel.is_complex, get_x=get_x  # type: ignore
            )
            # Always transpose vector to be able to properly filter them
            if not get_x:
                data = [a.T for a in aux[0]]
            else:
                x_data, data = [a.T for a in aux[0]], [a.T for a in aux[1]]
            vec_dim = data[0].shape[-1]
        else:
            data = self._get_data_data(
                channel.name, is_complex=getattr(channel, "is_complex", False)
            )

        # Filter the data based on the provided filters.
        filters = filters if filters is not None else {}

        # Only use positive indexes to describe filtering
        filters = {self._name_or_index_to_index(k): v for k, v in filters.items()}
        if filters:
            datasets = [self._file] + self._nested[:]
            results = []
            x_results = []
            for i, d in enumerate(datasets):
                masks = []
                for k, v in filters.items():
                    index = self._name_or_index_to_index(k)
                    filter_data = d["Data"]["Data"][:, index]
                    # Create the mask filtering the data
                    mask = np.less(np.abs(filter_data - v), filter_precision)
                    # If the mask is not empty, ensure we do not eliminate nans
                    # added by Labber to have complete sweeps
                    if np.any(mask):
                        mask |= np.isnan(filter_data)
                    masks.append(mask)

                mask = reduce(lambda x, y: x & y, masks)
                # For unclear reasons, vector data are not indexed in the same order
                # as other data and require the mask to be transposed before being
                # raveled
                if vectorial_data:
                    mask = np.ravel(mask.T)
                    if len(mask) > len(data[i]):
                        # if the scan was aborted, the mask (which is derived from the
                        # NaN-padded "/Data/Data" HDF5 dataset) may be larger than the
                        # vectorial data (which is derived from unpadded HDF5 datasets
                        # in "/Traces/")
                        mask = mask[: data[i].shape[0]]
                    elif len(mask) < len(data[i]):
                        # sometimes Labber inexplicably fails to record an incomplete
                        # scan in the "/Data/Data" HDF5 dataset, but still records the
                        # vectorial data points
                        logger.warning(
                            f"There are more traces ({len(data[i])}) recorded in "
                            f"'/Traces/' than there are step values ({len(mask)}) "
                            f"recorded in '/Data/Data'.  You might not be getting all "
                            f"the data from {self.filename} for {name_or_index=} with "
                            f"{filters=} and {get_x=}."
                        )
                        mask = np.append(mask, [False] * (len(data[i]) - len(mask)))

                # Filter
                results.append(data[i][mask])
                if vectorial_data and get_x:
                    x_results.append(x_data[i][mask])

                # If the filtering produces an empty output return early
                if not any(len(r) for r in results):
                    if get_x:
                        return np.empty(0), np.empty(0)
                    else:
                        return np.empty(0)

        else:
            results = data
            x_results = x_data

        # Identify the ramped steps not used for filtering
        steps_points = []
        first_step_is_used = False
        for i, s in reversed(list(enumerate(self.steps))):

            # Use ramps rather than is_ramped to consider only steps manually
            # ramped and exclude steps with multiple values because they have a
            # relation to a ramped step but no ramp of their own
            if s.ramps is not None and i not in filters:

                # Labber stores scalar data as 3D:
                # - the first dimension is the number of points of the first step
                # - the second one refers to the channel
                # - the third refers to all other steps, but in reverse order
                # For vector data it is the same except that the first dimension is not special.
                if i == 0 and not vectorial_data:
                    steps_points.insert(0, s.points_per_log)
                    first_step_is_used = True
                else:
                    steps_points.append(s.points_per_log)

        # For vectorial data we append the dimension of the vector at the end
        # of the step points, taking as many copies as there are logs.
        if vectorial_data:
            steps_points.append(
                (vec_dim,) * (len(steps_points[0]) if steps_points else 1)
            )

        # If we get a single value because we are accessing a value defined through
        # relations to channels we are filtering on, we can stop here and exit
        # early.
        if not steps_points:
            return results[0]  # [np.array([value])]

        # Get expected shape per log
        shape_per_logs = np.array(steps_points).T
        shaped_results = []
        shaped_x = []
        for i, shape in enumerate(shape_per_logs):
            # If the filtering produced an empty array skip it
            if results[i].shape == (0,):
                continue

            # The outermost dimension of the scan corresponds either to the first
            # index, if the first step was filtered on (not first_step_is_used), or,
            # otherwise, to the second.
            points_inner_dimensions = (
                np.prod(shape[1:])
                if not first_step_is_used
                else shape[0] * np.prod(shape[2:])
            )
            padding = np.prod(results[i].shape) % (points_inner_dimensions)
            # Pad the data to ensure that only the last axis shrinks
            if padding:
                # Compute the number of points to add
                to_add = points_inner_dimensions - padding
                results[i] = np.concatenate(
                    (results[i], np.nan * np.ones(to_add)), None
                )
                if vectorial_data and get_x:
                    x_results[i] = np.concatenate(
                        (x_results[i], np.nan * np.ones(to_add)), None
                    )

            # Allow the outermost dimension to shrink
            new_shape = list(shape)
            new_shape[1 if first_step_is_used and len(shape) > 1 else 0] = -1
            shaped_results.append(results[i].reshape(new_shape))
            if vectorial_data and get_x:
                shaped_x.append(x_results[i].reshape(new_shape))

        # Create the complete data
        full_data = np.concatenate(shaped_results, axis=-1)
        if vectorial_data and get_x:
            full_x = np.concatenate(shaped_x, axis=-1)

        # Move the first axis to last position so that the dimensions are in the reverse
        # order of the steps.
        if first_step_is_used:
            full_data = np.moveaxis(full_data, 0, -1)

        if vectorial_data and get_x:
            return full_x, full_data
        else:
            return full_data

    def warn_not_constant(
        self, name_or_index: Union[str, int], max_deviation: Optional[float] = None
    ):
        """Issue a warning if `name_or_index` data is not roughly constant.

        Parameters
        ----------
        name_or_index
            The name or index of the channel whose data will be checked.
        max_deviation : optional
            The largest deviation from the mean that will not issue a warning.
            If None, defaults to 1% of the mean.
        """
        data = self.get_data(name_or_index)
        mean = np.mean(data)
        if max_deviation is None:
            max_deviation = 0.01 * mean
        abs_deviation = np.abs(data - mean)
        if np.any(abs_deviation > max_deviation):
            warnings.warn(
                f"Channel `{name_or_index}` deviates from mean "
                f"by {np.max(abs_deviation)} > {max_deviation}"
            )

    def _name_or_index_to_index(self, name_or_index: Union[str, int]) -> int:
        """Provide the index of a channel from its name."""
        ch_names = self.channel_names

        if isinstance(name_or_index, str):
            if name_or_index not in ch_names:
                msg = (
                    f"The specified name ({name_or_index}) does not exist "
                    f"in the dataset. Existing names are {ch_names}"
                )
                raise ValueError(msg)
            return ch_names.index(name_or_index)
        elif name_or_index >= len(ch_names):
            msg = (
                f"The specified index ({name_or_index}) "
                f"exceeds the number of channel: {len(ch_names)}"
            )
            raise ValueError(msg)
        else:
            return name_or_index

    def _get_traces_data(
        self, channel_name: str, is_complex: bool = False, get_x: bool = False
    ) -> Tuple[List[np.ndarray], ...]:
        """Get data stored in the "/Traces/" HDF5 group of the Labber file.

        This is where vectorial data are stored."""
        if channel_name not in self._file["Traces"]:
            raise ValueError(f"Unknown traces data {channel_name}")

        x_data = []
        data = []
        # Traces dimensions are (sweep, real/imag/x, steps)
        for storage in [self._file] + self._nested:
            if is_complex:
                real = storage["Traces"][channel_name][:, 0, :]
                imag = storage["Traces"][channel_name][:, 1, :]
                data.append(real + 1j * imag)
            else:
                data.append(storage["Traces"][channel_name][:, 0, :])
            if get_x:
                x_data.append(storage["Traces"][channel_name][:, -1, :])

        if get_x:
            return x_data, data
        else:
            return (data,)

    def _get_data_data(
        self, channel_name: str, is_complex: bool = False
    ) -> List[np.ndarray]:
        """Get data stored in the "/Data/" HDF5 group of the Labber file.

        This is where ordinary (i.e. non-vectorial) data are stored."""
        names = [maybe_decode(n) for n, _ in self._file["Data"]["Channel names"]]
        if is_complex:
            re_index = names.index(channel_name)
            im_index = re_index + 1
            real = self._pull_nested_data(re_index)
            imag = self._pull_nested_data(im_index)
            return [r + 1j * i for r, i in zip(real, imag)]
        else:
            return self._pull_nested_data(names.index(channel_name))

    def _pull_nested_data(self, index: int) -> List[np.ndarray]:
        """Pull data stored in the data segmentfrom all nested logs."""
        data = [self._file["Data"]["Data"][:, index]]
        for internal in self._nested:
            data.append(internal["Data"]["Data"][:, index])
        return data

    def __enter__(self) -> "LabberData":
        """ Open the underlying HDF5 file when used as a context manager."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the underlying HDF5 file when used as a context manager."""
        self.close()
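
# Usage sketch (not part of the original class): LabberData appears to be a dataclass
# taking the path to a Labber HDF5 file as its first argument and doubles as a context
# manager. The file path and the channel names below are purely illustrative.
if __name__ == "__main__":
    with LabberData("my_measurement.hdf5") as data:
        print(data.channel_names)  # ramped step channels followed by log channels
        # Retrieve a channel by name or by index
        gate = data.get_data("Gate voltage")
        first = data.get_data(0)
        # Keep only the points at which another (hypothetical) channel is 0.5
        filtered = data.get_data("Gate voltage", filters={"Bias voltage": 0.5})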
示例#30
0
###### load data #####
# signal 1 path
file_folder = "/Users/shadi/PycharmProjects/plux_internship"
file_name = "BVP_RESPchest_3_2_3.h5"
#file_name = "BVP_RESPchest_4_1_4.h5"
#file_name = "BVP_RESPchest_3_2_3_4_1_4.h5"

file_path1 = file_folder + "/" + file_name

#file_path1 = "BVP_RESPchest_4_1_4.h5"

# file 1
h5_object1 = File(file_path1)
list(h5_object1.keys())

h5_group1 = h5_object1.get('00:07:80:58:9B:3F')
#print ("Second hierarchy level: " + str(list(h5_group)))
#print ("Metadata of h5_group: \n" + str(list(h5_group.attrs.keys())))
sampling_rate = h5_group1.attrs.get("sampling rate")
#print ("Sampling Rate: " + str(sampling_rate))
h5_sub_group1 = h5_group1.get("raw")
print("Third hierarchy level: " + str(list(h5_sub_group1)))
h5_data1 = h5_sub_group1.get("channel_1")
h5_data2 = h5_sub_group1.get("channel_3")
data_list1 = [item for sublist in h5_data1 for item in sublist]
miquel = data_list1[70 * sampling_rate:len(data_list1) - 10 * sampling_rate]
miquel_np = np.array(miquel)

data_list2 = [item for sublist in h5_data2 for item in sublist]
shadi = data_list2[70 * sampling_rate:len(data_list2) - 10 * sampling_rate]
shadi_np = np.array(shadi)
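
# A shorter alternative (a sketch, assuming the "channel_*" datasets are 2-D, e.g. of
# shape (n, 1)): loading the dataset into NumPy and flattening it preserves the same
# ordering as the nested list comprehensions above, without the Python-level loops.
fs = int(sampling_rate)
arr1 = np.asarray(h5_data1).ravel()
miquel_np_alt = arr1[70 * fs:arr1.size - 10 * fs]
arr2 = np.asarray(h5_data2).ravel()
shadi_np_alt = arr2[70 * fs:arr2.size - 10 * fs]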
示例#31
0
def experimentsAbalone():
    budgets = [3,10,50,100,200,500,1000,1500,2000,3000,4000,4177]
    Sigmas = range(-10,5,1)
    Gammas = [1.0,0.9,0.8,0.7]
    Lambdas = [0.0,0.01,0.1,0.2,0.3,0.4]
    Alphas = [0.4,0.5,0.6,0.7]
    runs = 10
    acc = np.zeros((runs,),dtype=np.float64)
    exRuns = 30
    results = dict()
    f = File('../Datasets/abalone.h5','r')
    X = f.get('data')[:,:]
    lg = f.get('labels')[:]
    f.close()
    k = 3
    # Without K-means budget
    for budget in budgets:
        print "Tuning with budget size = {}".format(budget)
        finalAcc = np.zeros((exRuns,),dtype=np.float64)
        finalTime = np.zeros((exRuns,),dtype=np.float64)
        accT = np.zeros((len(Gammas),),dtype=np.float64)
        for i in xrange(len(Gammas)):
            print "Tuning Gamma {} of {}".format(i+1,len(Gammas))
            Gamma = Gammas[i]
            for j in xrange(runs):
                ok = OKMF(budget,k,50,4,Gamma,0.1,0.2,'rbf')
                try:
                    ok.fit(X)
                    lf = np.argmax(ok.H,axis=0)
                    acc[j] = accuracy(lf,lg)
                except np.linalg.LinAlgError as er:
                    print er.message
                    acc[j] = 0.0
            accT[i] = np.average(acc)
        Gamma = Gammas[np.argmax(accT)]
        accT = np.zeros((len(Lambdas),),dtype=np.float64)
        for i in xrange(len(Lambdas)):
            print "Tuning Lambda {} of {}".format(i+1,len(Lambdas))
            Lambda = Lambdas[i]
            for j in xrange(runs):
                ok = OKMF(budget,k,50,4,Gamma,Lambda,0.2,'rbf')
                try:
                    ok.fit(X)
                    lf = np.argmax(ok.H,axis=0)
                    acc[j] = accuracy(lf,lg)
                except np.linalg.LinAlgError as er:
                    print er.message
                    acc[j] = 0.0
            accT[i] = np.average(acc)
        Lambda = Lambdas[np.argmax(accT)]
        accT = np.zeros((len(Alphas),),dtype=np.float64)
        for i in xrange(len(Alphas)):
            print "Tuning Alpha {} of {}".format(i+1,len(Alphas))
            Alpha = Alphas[i]
            for j in xrange(runs):
                ok = OKMF(budget,k,50,4,Gamma,Lambda,Alpha,'rbf')
                try:
                    ok.fit(X)
                    lf = np.argmax(ok.H,axis=0)
                    acc[j] = accuracy(lf,lg)
                except np.linalg.LinAlgError as er:
                    print er.message
                    acc[j] = 0.0
            accT[i] = np.average(acc)
        Alpha = Alphas[np.argmax(accT)]
        accT = np.zeros((len(Sigmas),),dtype=np.float64)
        for i in xrange(len(Sigmas)):
            print "Tuning Sigma {} of {}".format(i+1,len(Sigmas))
            Sigma = Sigmas[i]
            sigmaR = 1.0 / ((2.0 ** Sigma) ** 2.0)
            for j in xrange(runs):
                ok = OKMF(budget,k,50,4,Gamma,Lambda,Alpha,'rbf',gamma=sigmaR)
                try:
                    ok.fit(X)
                    lf = np.argmax(ok.H,axis=0)
                    acc[j] = accuracy(lf,lg)
                except np.linalg.LinAlgError as er:
                    print er.message
                    acc[j] = 0.0
            accT[i] = np.average(acc)
        Sigma = Sigmas[np.argmax(accT)]
        sigmaR = 1.0 / ((2.0 ** Sigma) ** 2.0)
        for i in xrange(exRuns):
            ok = OKMF(budget,k,50,4,Gamma,Lambda,Alpha,'rbf',gamma=sigmaR)
            try:
                t0 = clock()
                ok.fit(X)
                lf = np.argmax(ok.H,axis=0)
                t1 = clock()
                finalAcc[i] = accuracy(lf,lg)
                finalTime[i] = t1 - t0
            except np.linalg.LinAlgError as er:
                print er.message
                finalAcc[i] = 0
                finalTime[i] = float('inf')
        results[str(budget)]=finalAcc,finalTime
    return results
def read_airs_from_wavs(wav_files,
                        framesize=None,
                        get_pow_spec=True,
                        max_air_len=None,
                        fs=None,
                        forced_fs=None,
                        keep_ids=None,
                        cacheloc='/tmp/',
                        start_at_max=True,
                        read_cached_latest=False,
                        wavform_logpow=False,
                        write_cached_latest=True,
                        max_speech_read=None,
                        max_air_read=None,
                        utt_per_env=1,
                        parse_as_dirctories=True,
                        speech_files=None,
                        save_speech_associations=True,
                        save_speech_examples=10,
                        drop_speech=False,
                        as_hdf5_ds=True,
                        choose_channel=None,
                        no_fex=False,
                        scratchpad='/tmp/',
                        copy_associations_to=None,
                        given_associations=None):
    """

    Given a set of AIR files and additional information, data for the training of DNNs for
    environment classification are prepared.

    Args:
        wav_files: Location of AIR wav files
        framesize: The framesize to use
        get_pow_spec: Convert audio to log-power spectrum domain
        max_air_len: The maximum length of the signals (truncate to or pad to)
        fs: The sampling frequency of the wav files to expect
        forced_fs: The sampling frequency to convert the data to
        keep_ids: None (not used)
        cacheloc: Location to use for cache reading and saving
        start_at_max: Modify the signals so that the maximum energy sample is at the
        beginning (can be used to align AIRs)
        read_cached_latest: Read the data from the last saved cache (if any)
        wavform_logpow: Get the signals in the log-power time domain
        write_cached_latest: Write the collected data in a cache for fast reuse
        max_speech_read: Maximum length of speech signal to read
        max_air_read: Maximum AIR length to read up to
        utt_per_env: Number of utterances to convolve with each AIR
        parse_as_dirctories: Parse the inputs as directories and not as individual files
        speech_files: Locations of speech files
        save_speech_associations: Save the speech associations with the corresponding AIRs
        save_speech_examples: Enable the saving of examples of the reverberant speech created
        drop_speech: Do not include the speech samples in the saving of the cache or in the
        RAM. Keep only the training data arrays
        as_hdf5_ds: Keep the data as HDF5 datasets on disk (reduces RAM usage a lot)
        choose_channel: Channels to use for each AIR
        no_fex: Skip the data processing phase and just return the raw signals
        scratchpad: Location to use for temporary saving
        copy_associations_to: Save a copy of the speech-AIR associations here
        given_associations: Provided associations between speech files and AIRs. This can be
        used in the case where you want to use specific speech samples for specific AIRs

    Returns:
        (X, None), Sample_names, None,
        (AIRs, Speech, Reverberant_speech),
        (Group_name, Groups), Number_of_utterances_convolved_with_each_AIR

    """
    try:
        from os.path import isfile, basename
    except ImportError:
        raise
    from scipy.signal import fftconvolve
    import numpy as np
    from h5py import File
    from scipy.io import wavfile
    from utils_spaudio import my_resample, write_wav
    from utils_base import find_all_ft, run_command
    from random import sample
    import pandas as pd
    from random import randint
    from time import time

    run_command('mkdir -p ' + cacheloc)
    latest_file = cacheloc + '/training_test_data_wav.h5'
    timestamp = str(time())
    filename_associations = scratchpad + '/air_speech_associations_' + timestamp + '.csv'
    base_examples_dir = scratchpad + '/feature_extraction_examples/'
    if keep_ids is not None:
        raise AssertionError('No ids exist in this context')
    if speech_files is None:
        utt_per_env = 1
        if save_speech_associations:
            print(
                'There is no speech to save in associations, setting to false')
            save_speech_associations = False
        if save_speech_examples:
            print(
                'There is no speech to save audio for, setting to 0 examples')
            save_speech_examples = 0

    try:
        hf = None
        if isfile(latest_file) and read_cached_latest:
            print('Reading :' + latest_file)
            hf = File(latest_file, 'r')
            if as_hdf5_ds:
                x = hf['x']
                ids = hf['ids']
                airs = hf['airs']
                utt_per_env = np.array(hf['utts'])
                rev_speech = hf['rev_names']
                clean_speech = hf['clean_speech']
                print('Done creating handles to : ' + latest_file)
            else:
                utt_per_env = np.array(hf['utts'])
                x = np.array(hf.get('x'))
                ids = np.array(hf.get('ids'))
                airs = np.array(hf.get('airs'))
                rev_speech = np.array(hf.get('rev_names'))
                clean_speech = np.array(hf.get('clean_speech'))
                print('Done reading : ' + latest_file)
            if given_associations is not None:
                print(
                    '! I read the cache so the given associations were not used'
                )
            if copy_associations_to is not None:
                print(
                    '! I read the cache so the associations could not be saved at '
                    + copy_associations_to)
            return (x, None), ids, None, (airs, clean_speech,
                                          rev_speech), utt_per_env
    except (ValueError, KeyError) as ME:
        print('Tried to read ' + latest_file + ' but failed with ' +
              ME.message)
        if hf is not None:
            hf.close()

    if given_associations is not None:
        print('You gave me speech associations, Speech: ' +
              str(len(given_associations['speech'])) +
              ' entries and Offsets: ' +
              str(len(given_associations['offsets'])) + ' entries')

    ids = None
    x = None
    x_speech = None
    x_rev_speech = None

    if forced_fs is None:
        forced_fs = fs
    resample_op = lambda x: x
    if not forced_fs == fs:
        resample_op = lambda x: np.array(
            my_resample(np.array(x.T, dtype=float), fs, forced_fs)).T

    if max_air_read is not None:
        if fs is None:
            raise AssertionError('Cannot work with max_air_read without fs')
        max_air_read_samples = int(np.ceil(fs * max_air_read))
    if max_speech_read is not None:
        if fs is None:
            raise AssertionError('Cannot work with max_speech_read without fs')
        max_speech_read_samples = int(np.ceil(fs * max_speech_read))
    else:
        max_speech_read_samples = None

    if parse_as_dirctories:
        if not type(wav_files) is list:
            wav_files = [wav_files]
        wav_files = find_all_ft(wav_files, ft='.wav', find_iname=True)
    if speech_files is not None:
        if not type(speech_files) is list:
            speech_files = [speech_files]
        speech_files = find_all_ft(speech_files, ft='.wav', find_iname=True)

    if save_speech_examples:
        run_command('rm -r ' + base_examples_dir)
        run_command('mkdir -p ' + base_examples_dir)

    associations = []
    save_counter = 0
    all_names = [
        basename(i).replace('.wav', '') + '_' + str(j) for i in wav_files
        for j in range(utt_per_env)
    ]
    if type(choose_channel) is list:
        choose_channel = [
            i for i in choose_channel for _ in range(utt_per_env)
        ]
    wav_files = [i for i in wav_files for _ in range(utt_per_env)]
    offsets = []
    for i, this_wav_file in enumerate(wav_files):
        if False and speech_files is not None:
            print "Reading: " + this_wav_file + " @ " + str(
                i + 1) + " of " + str(len(wav_files)),
        names = [all_names[i]]
        this_fs, airs = wavfile.read(this_wav_file)
        airs = airs.astype(float)
        if airs.ndim > 1:
            if choose_channel is not None:
                if type(choose_channel) is list:
                    airs = airs[:, choose_channel[i]]
                    names[0] += '_ch' + str(choose_channel[i])
                else:
                    airs = airs[:, choose_channel]
                    names[0] += '_ch' + str(choose_channel)
            else:
                names = [
                    names[0] + '_' + str(ch_id)
                    for ch_id in range(airs.shape[1])
                ]
            airs = airs.T
        airs = np.atleast_2d(airs)
        airs /= np.repeat(np.atleast_2d(abs(airs).max()).T, airs.shape[1],
                          1).astype(float)
        if airs.shape[0] > 1 and given_associations is not None:
            raise AssertionError(
                'Cannot work out given associations for multichannel airs')
        this_speech_all = []
        this_rev_speech_all = []
        if speech_files is not None:
            for ch_id in range(airs.shape[0]):
                if given_associations is None:
                    chosen_file = sample(range(len(speech_files)), 1)[0]
                    this_speech_file = speech_files[chosen_file]
                else:
                    chosen_file = given_associations['speech'][i]
                    this_speech_file = chosen_file
                associations.append(chosen_file)
                this_speech_fs, this_speech = wavfile.read(this_speech_file)
                if this_speech.ndim > 1:
                    raise AssertionError(
                        'Can\'t deal with multichannel speech in this context')
                if not this_speech_fs == this_fs:
                    this_speech = my_resample(this_speech, this_speech_fs,
                                              this_fs)
                max_offset_for_check = None
                if max_speech_read_samples is not None:
                    max_offset_for_check = this_speech.size - max_speech_read_samples
                    offset = randint(
                        0, this_speech.size - max_speech_read_samples)
                    this_speech = this_speech[offset:offset +
                                              max_speech_read_samples]
                else:
                    offset = 0
                if given_associations is not None:
                    offset = given_associations['offsets'][i]
                    if max_speech_read_samples is not None:
                        if offset >= max_offset_for_check:
                            raise AssertionError(
                                'Invalid offset from given associations, got '
                                + str(offset) + ' expected max is ' +
                                str(this_speech.size -
                                    max_speech_read_samples))

                conv_air = np.array(airs[ch_id, :])
                nonzero_idx = np.where(~(conv_air == 0))[-1]
                conv_air = conv_air[nonzero_idx[0]:nonzero_idx[-1]]

                # Making convolution
                this_rev_speech = fftconvolve(this_speech, conv_air, 'same')
                #

                dp_arival = np.argmax(abs(conv_air))
                this_rev_speech = this_rev_speech[dp_arival:]
                if dp_arival > 0:
                    this_rev_speech = np.concatenate(
                        (this_rev_speech,
                         np.zeros(dp_arival, dtype=this_rev_speech.dtype)))

                this_speech = np.atleast_2d(this_speech)
                this_rev_speech = np.atleast_2d(this_rev_speech)
                this_speech_all.append(this_speech)
                this_rev_speech_all.append(this_rev_speech)

                offsets.append(offset)
                if save_speech_examples >= save_counter:
                    save_names = [
                        basename(this_wav_file).replace('.wav', '') + '_air_' +
                        str(offset) + '.wav',
                        basename(this_wav_file).replace('.wav', '') +
                        '_clean_speech_' + str(offset) + '.wav',
                        basename(this_wav_file).replace('.wav', '') +
                        '_rev_speech_' + str(offset) + '.wav'
                    ]
                    for examples in range(len(save_names)):
                        save_names[examples] = base_examples_dir + save_names[
                            examples]
                    write_wav(save_names[0], this_fs, airs[ch_id, :])
                    write_wav(save_names[1], this_fs, this_speech.flatten())
                    write_wav(save_names[2], this_fs,
                              this_rev_speech.flatten())
                    save_counter += 1
            this_speech = np.concatenate(this_speech_all, axis=0)
            this_rev_speech = np.concatenate(this_rev_speech_all, axis=0)

        if not this_fs == fs:
            raise AssertionError('Your sampling rates are not consistent')
        if i > 0:
            ids = np.concatenate((ids, names))
        else:
            ids = names

        if max_air_read is not None:
            airs = airs[:, 0:max_air_read_samples]
        if False and speech_files is not None:
            print("Got " + str(airs.shape))
        airs = resample_op(airs)
        if airs.ndim < 2:
            airs = np.atleast_2d(airs)
        # print('Done resampling')
        if i > 0:
            if x.shape[1] < airs.shape[1]:
                npads = -x.shape[1] + airs.shape[1]
                x = np.concatenate((x, np.zeros(
                    (x.shape[0], npads)).astype(x.dtype)),
                                   axis=1)
                x = np.concatenate((x, airs), axis=0)
            else:
                if x.shape[1] > airs.shape[1]:
                    npads = x.shape[1] - airs.shape[1]
                    airs = np.concatenate(
                        (airs, np.zeros(
                            (airs.shape[0], npads)).astype(airs.dtype)),
                        axis=1)
                x.resize((x.shape[0] + airs.shape[0], x.shape[1]),
                         refcheck=False)
                x[-airs.shape[0]:, :] = np.array(airs)

            if speech_files is not None:
                if x_speech.shape[1] < this_speech.shape[1]:
                    npads = -x_speech.shape[1] + this_speech.shape[1]
                    x_speech = np.concatenate(
                        (x_speech, np.zeros((x_speech.shape[0], npads)).astype(
                            x_speech.dtype)),
                        axis=1)
                    x_speech = np.concatenate((x_speech, this_speech), axis=0)
                else:
                    if x_speech.shape[1] > this_speech.shape[1]:
                        npads = x_speech.shape[1] - this_speech.shape[1]
                        this_speech = np.concatenate(
                            (this_speech,
                             np.zeros((this_speech.shape[0], npads)).astype(
                                 this_speech.dtype)),
                            axis=1)
                    x_speech.resize((x_speech.shape[0] + this_speech.shape[0],
                                     x_speech.shape[1]),
                                    refcheck=False)
                    x_speech[-this_speech.shape[0]:, :] = this_speech

                if x_rev_speech.shape[1] < this_rev_speech.shape[1]:
                    npads = -x_rev_speech.shape[1] + this_rev_speech.shape[1]
                    x_rev_speech = np.concatenate(
                        (x_rev_speech, np.zeros(
                            (x_rev_speech.shape[0], npads)).astype(
                                x_rev_speech.dtype)),
                        axis=1)
                    x_rev_speech = np.concatenate(
                        (x_rev_speech, this_rev_speech), axis=0)
                else:
                    if x_rev_speech.shape[1] > this_rev_speech.shape[1]:
                        npads = x_rev_speech.shape[1] - this_rev_speech.shape[1]
                        this_rev_speech = np.concatenate(
                            (this_rev_speech,
                             np.zeros(
                                 (this_rev_speech.shape[0], npads)).astype(
                                     this_rev_speech.dtype)),
                            axis=1)
                    x_rev_speech.resize(
                        (x_rev_speech.shape[0] + this_rev_speech.shape[0],
                         x_rev_speech.shape[1]),
                        refcheck=False)
                    x_rev_speech[
                        -this_rev_speech.shape[0]:, :] = this_rev_speech
        else:
            x = np.array(airs)
            if speech_files is not None:
                x_speech = np.array(this_speech)
                x_rev_speech = np.array(this_rev_speech)

    if save_speech_associations:
        from utils_base import run_command
        df = pd.DataFrame({
            'air':
            wav_files,
            'speech':
            np.array(speech_files)[associations]
            if given_associations is None else given_associations['speech'],
            'offsets':
            offsets
            if given_associations is None else given_associations['offsets']
        })

        df.to_csv(filename_associations, index=False)
        print('Saved: ' + filename_associations +
              ('' if given_associations is None else
               ' which was created from the given associations'))
        if copy_associations_to is not None:
            run_command('cp ' + filename_associations + ' ' +
                        copy_associations_to)
            print('Saved: ' + copy_associations_to)

    if fs is not None:
        print('Got ' + str(x.shape[0]) + ' AIRs of duration ' +
              str(x.shape[1] / float(fs)))
    else:
        print('Got ' + str(x.shape[0]) + ' AIRs of length ' + str(x.shape[1]))

    if speech_files is not None:
        proc_data = x_rev_speech
    else:
        proc_data = x

    if drop_speech:
        x_rev_speech = []
        x_speech = []
        x = []

    if no_fex:
        x_out = None
        print('Skipping feature extraction')
    else:
        x_out = data_post_proc(np.array(proc_data), forced_fs, start_at_max,
                               framesize, get_pow_spec, max_air_len,
                               wavform_logpow)

        print('Left with ' + str(x_out.shape) + ' AIR features data ')

    ids = ids.astype(str)

    wrote_h5 = False
    if write_cached_latest:
        try:
            hf = File(latest_file, 'w')
            if no_fex:
                hf.create_dataset('x', data=[])
            else:
                hf.create_dataset('x', data=x_out)
            hf.create_dataset('y', data=[])
            hf.create_dataset('ids', data=ids)
            hf.create_dataset('class_names', data=[])
            hf.create_dataset('airs', data=x)
            hf.create_dataset('utts', data=utt_per_env)
            if speech_files is not None:
                hf.create_dataset('clean_speech', data=x_speech)
                hf.create_dataset('rev_names', data=x_rev_speech)
            else:
                hf.create_dataset('clean_speech', data=[])
                hf.create_dataset('rev_names', data=[])
            hf.close()
            wrote_h5 = True
            print('Wrote: ' + str(latest_file))
        except IOError as ME:
            print('Cache writing failed with ' + str(ME.message))

        if (not wrote_h5) and as_hdf5_ds:
            raise AssertionError('Could not provide data in correct format')
        if as_hdf5_ds:
            hf = File(latest_file, 'r')
            x_out = hf['x']
            ids = hf['ids']
            x = hf['airs']
            x_speech = hf['clean_speech']
            x_rev_speech = hf['rev_names']
            # hf.close()

    return (x_out, None), ids, None, (x, x_speech, x_rev_speech), utt_per_env
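
# Usage sketch (not part of the original function): read_airs_from_wavs is typically
# pointed at directories of AIR and speech wav files. The paths, sampling rate, frame
# size and maximum lengths below are illustrative assumptions only.
if __name__ == "__main__":
    (x_feat, _), sample_ids, _, (airs, speech, rev_speech), utts = read_airs_from_wavs(
        ['../Local_Databases/AIR/'],                   # hypothetical AIR directory
        speech_files=['../Local_Databases/speech/'],   # hypothetical speech directory
        fs=16000, forced_fs=16000, framesize=512,
        max_air_len=1.0, max_air_read=1.0, max_speech_read=5.0,
        utt_per_env=2, as_hdf5_ds=False, read_cached_latest=False)
    print('Features: ' + str(None if x_feat is None else x_feat.shape) +
          ', samples: ' + str(len(sample_ids)))
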
def get_ace_xy(h5_file='../results_dir/ace_h5_info.h5',
               ace_base='../Local_Databases/AIR/ACE/',
               y_type='room',
               group_by=None,
               utt_per_env=1,
               speech_files=None,
               print_distributions=False,
               parse_as_dirctories=False,
               choose_channel=None,
               **kwargs):
    """

    Collects training data and labels for the training of DNNs using ace_discriminative_nets,
    based on the ACE Challenge data [1].

    Args:
        h5_file: Location of the HDF5 dataset file for the ACE database, which is provided
        with this repository at Code/results_dir/ace_h5_info.h5. Contains information about
        the filenames, number of channels and also ground truth acoustic parameter values.
        If you want to create a new one, then use fe_utils.compile_ace_h5
        ace_base: The location of the ACE database data
        y_type: Creating labels from the data using specific information. This
        can be either of:
             'room', 'type', 'array', 'position', 'channel'
        group_by: Creating grouping information from the data using specific information. This
        can be either of:
             'room', 'recording', 'array', 'position', 'air', 'channel'
        utt_per_env: Number of speech utterances to convolve with each AIR
        speech_files: Speech directory to pick up speech from and convolve it with the AIRs
        print_distributions: Print data information with regards to class distributions
        parse_as_dirctories: (ignored)
        choose_channel: (ignored)
        **kwargs: Additional arguments to be passed to read_airs_from_wavs

    Returns:
        (X, Y), Sample_names, Class_names,
        (AIRs, Speech, Reverberant_speech),
        (Group_name, Groups)

    """
    from h5py import File
    import numpy as np
    try:
        from os.path import basename
    except ImportError:
        raise
    from utils_base import flatten_list
    parse_as_dirctories = False

    hf = File(h5_file, 'r')
    wav_files = (np.array(hf.get('filenames')).astype(str)).tolist()
    chan = (np.array(hf.get('chan')).astype(int) - 1).tolist()

    type_dict = {
        '502': 'Office',
        '803': 'Office',
        '503': 'Meeting_Room',
        '611': 'Meeting_Room',
        '403a': 'Lecture_Room',
        '508': 'Lecture_Room',
        'EE-lobby': 'Building_Lobby'
    }
    basenames = [
        thename.split('/')[-1].replace('EE_lobby', 'EE-lobby')
        for thename in wav_files
    ]
    room = [thename.split('_')[1] for thename in basenames]
    array = [thename.split('_')[0] for thename in basenames]
    room_type = [type_dict[thename.split('_')[1]] for thename in basenames]
    recording = basenames

    if ace_base is None:
        x_out = None
        x = None
        x_speech = None
        x_rev_speech = None
        ids = flatten_list([[
            basename(this_file).replace('.wav', '') + '_' + str(j) + '_ch' +
            str(k) for k in range(chan[i])
        ] for i, this_file in enumerate(wav_files)
                            for j in range(utt_per_env)])
    else:
        for i in range(len(wav_files)):
            wav_files[i] = ace_base + '/' + wav_files[i]
        (x_out, _), \
        ids, _, \
        (x, x_speech, x_rev_speech), \
        utt_per_env = read_airs_from_wavs(
            wav_files, utt_per_env=utt_per_env, speech_files=speech_files,
            parse_as_dirctories=parse_as_dirctories,
            choose_channel=chan,
            **kwargs)
    if 'ch' not in ids[0]:
        if np.sum(['ch' in ids[i] for i in range(len(ids))]) > 0:
            raise AssertionError('Unexpected condition')
        ch = [0 for _ in range(len(ids))]
    else:
        ch = [int(i.split('ch')[1]) for i in ids]

    y = []
    class_names = []

    flat_back_y = False
    if not (isinstance(y_type, list) or isinstance(y_type, tuple)):
        flat_back_y = True
        y_type = (y_type, )

    for this_y_type in y_type:
        if this_y_type == 'room':
            new_y, new_class_names, _ = categorical_to_mat(room)
            def_group_by = 'recording'
        elif this_y_type == 'type':
            new_y, new_class_names, _ = categorical_to_mat(room_type)
            def_group_by = 'room'
        elif this_y_type == 'array':
            new_y, new_class_names, _ = categorical_to_mat(array)
            def_group_by = 'recording'
        elif this_y_type == 'position':
            new_y, new_class_names, _ = categorical_to_mat(recording)
            def_group_by = 'air'
        elif this_y_type == 'channel':
            new_y, new_class_names, _ = categorical_to_mat(ch)
            def_group_by = 'position'
        else:
            raise AssertionError('Invalid y_type')
        y.append(new_y)
        class_names.append(new_class_names)

    flat_back_groups = False
    if group_by is None:
        group_by = (def_group_by, )
    elif not (isinstance(group_by, list) or isinstance(group_by, tuple)):
        flat_back_groups = True
        group_by = (group_by, )

    group_name, groups = ([], [])
    for this_group_by in group_by:
        if this_group_by == 'recording' or this_group_by == 'position':
            _, new_group_name, new_groups = categorical_to_mat(recording)
        elif this_group_by == 'room':
            _, new_group_name, new_groups = categorical_to_mat(room)
        elif this_group_by == 'array':
            _, new_group_name, new_groups = categorical_to_mat(array)
        elif this_group_by == 'air':
            new_groups = np.atleast_2d(np.arange(y.shape[0])).T
            new_group_name = np.array(ids)
        elif this_group_by == 'channel':
            max_ch = max(ch) + 1
            ch_array = np.zeros((len(ch), max_ch), dtype=bool)
            for i in range(len(ch)):
                ch_array[i, ch[i]] = True
            new_groups = np.array(ch_array)
            new_group_name = np.array(['ch_' + str(i) for i in range(max_ch)])
        else:
            raise AssertionError('Invalid group_by ' + this_group_by)
        group_name.append(new_group_name)
        groups.append(new_groups)

    for i in range(len(y)):
        if print_distributions:
            print_split_report(y[i])
        if np.any(~(np.sum(y[i], axis=1) == 1)):
            raise AssertionError('Invalid y outputs')
        y[i] = np.concatenate([
            y[i][ii:ii + 1, :] for ii in range(y[i].shape[0])
            for _ in range(utt_per_env)
        ],
                              axis=0)
    for ii in range(len(groups)):
        groups[ii] = [
            np.concatenate([
                list(range(i * utt_per_env, (i + 1) * utt_per_env))
                for i in groups[ii][j]
            ]).astype(int) for j in range(len(groups[ii]))
        ]

    y = tuple(y)
    class_names = tuple(class_names)
    groups = tuple(groups)
    group_name = tuple(group_name)
    if flat_back_groups:
        groups = groups[0]
        group_name = group_name[0]
    if flat_back_y:
        y = y[0]
        class_names = class_names[0]

    return (x_out, y), ids, class_names, (x, x_speech,
                                          x_rev_speech), (group_name, groups)
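
# Usage sketch (not part of the original function): a typical call collecting
# room-classification data from the ACE corpus, grouped by recording position.
# The base paths and the keyword arguments forwarded to read_airs_from_wavs are
# assumptions for illustration.
if __name__ == "__main__":
    (x_all, y_all), names, classes, (airs, speech, rev_speech), (gname, groups) = \
        get_ace_xy(h5_file='../results_dir/ace_h5_info.h5',
                   ace_base='../Local_Databases/AIR/ACE/',
                   y_type='room', group_by='recording',
                   utt_per_env=1, speech_files=None,
                   fs=16000, forced_fs=16000, framesize=512, max_air_len=1.0,
                   as_hdf5_ds=False, read_cached_latest=False)
    print('Classes: ' + str(classes) + ', samples: ' + str(len(names)))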