Ejemplo n.º 1
0
    def segment(self, exp_file, stim_file, output_file):
        """Interactively segment every stimulus played in an experiment.

        Stimulus ids are collected from every epoch table of the experiment.
        Each stimulus whose id is not already the first field of a line in
        ``output_file`` is shown with ``self.show_stim``; the times it returns
        are appended to ``output_file`` as one line ``stim_id,t0,t1,...``.
        The file is flushed after each stimulus, so an interrupted session
        can be resumed by calling this method again with the same file.

        :param exp_file: experiment file passed to ``Experiment.load``.
        :param stim_file: stimulus file passed to ``Experiment.load``.
        :param output_file: text file that receives one line per stimulus.
        """
        exp = Experiment.load(exp_file, stim_file)
        spec_colormap()

        # collect the unique stimulus ids across all epoch tables
        all_stim_ids = list()
        for ekey in exp.epoch_table.keys():
            etable = exp.epoch_table[ekey]
            all_stim_ids.extend(etable['id'].unique())
        stim_ids = np.unique(all_stim_ids)

        # read all the stims that are already segmented; lines from file
        # iteration keep their '\n', so test strip() to skip blank lines
        # (int('\n') would raise ValueError)
        finished_stims = list()
        if os.path.exists(output_file):
            with open(output_file, 'r') as f:
                finished_stims = [int(ln.split(',')[0]) for ln in f if ln.strip()]
            print('finished_stims=', finished_stims)
        finished = set(finished_stims)  # O(1) membership checks below

        # manually segment the stimuli
        print('# of stims: %d' % len(stim_ids))
        with open(output_file, 'a') as ofd:
            for stim_id in stim_ids:
                if stim_id in finished:
                    continue
                stimes = self.show_stim(exp, stim_id)
                ofd.write('{},{}\n'.format(
                    stim_id, ','.join(['%f' % s for s in stimes])))
                # flush so each result survives a crash/interrupt mid-session
                ofd.flush()
Ejemplo n.º 2
0
    def plot_all_slices(self, slice_len=3.0, output_dir=None):
        """Plot consecutive time slices of ``slice_len`` seconds covering ``[0, self.duration]``.

        Side effect: ``self.X`` is z-scored in place (each column centered and
        divided by its ddof=1 standard deviation) before plotting.

        One random light color per component is drawn each call, so colors
        differ between calls. Every slice is rendered by ``self.plot_slice``
        with the same colors and the same 99th-percentile absolute amplitude
        (``absmax``), and all figures are closed after each slice.

        :param slice_len: length of each slice; the last slice may be shorter.
        :param output_dir: forwarded unchanged to ``self.plot_slice``.
        """
        # z-score the data (in place)
        self.X -= self.X.mean(axis=0)
        self.X /= self.X.std(axis=0, ddof=1)
        absmax = np.percentile(np.abs(self.X), 99)

        # generate random colors
        C = np.random.rand(len(self.components), 3)

        # rescale random colors into [0.25, 1.0) so they're light on a black background
        C *= 0.75
        C += 0.25

        print('Loading experiment...')
        exp = Experiment.load(self.experiment_file,
                              self.stimulus_file,
                              read_only_stims=True)
        seg = exp.get_segment(self.block_name, self.segment_name)

        # Slice boundaries. np.arange excludes its stop value, so the final
        # slice (full or partial) would be silently skipped unless duration is
        # appended. If arange's last start already lands on duration (float
        # rounding), snap it instead of emitting a near-empty slice.
        slices = np.arange(0, self.duration, slice_len)
        if len(slices) > 0 and np.isclose(slices[-1], self.duration):
            slices[-1] = self.duration
        else:
            slices = np.append(slices, self.duration)

        for stime, etime in zip(slices[:-1], slices[1:]):
            print('Plotting slice from %0.6fs to %0.6fs' % (stime, etime))
            self.plot_slice(stime,
                            etime,
                            exp=exp,
                            seg=seg,
                            colors=C,
                            absmax=absmax,
                            output_dir=output_dir)
            plt.close('all')
Ejemplo n.º 3
0
    def load(clz, output_file, load_real=True, load_complex=True):
        """Rebuild a ``MEMDTransform`` from an HDF5 file.

        Scalar/array settings are read from the file's root attributes. The
        ``Experiment`` is re-loaded from the experiment/stimulus paths stored
        in those attributes, and the stored block/segment is looked up in it.

        :param output_file: path of the HDF5 file to read.
        :param load_real: if True, read dataset ``'bands'`` into ``mt.X``;
            otherwise ``mt.X`` stays None.
        :param load_complex: if True, read dataset ``'bands_analytic'`` into
            ``mt.Z``; otherwise ``mt.Z`` stays None.
        :return: the populated ``MEMDTransform``.

        NOTE(review): always constructs ``MEMDTransform`` rather than ``clz``,
        so subclasses calling this get a base-class instance — confirm intent.
        """
        # the context manager closes the file even if an attribute is missing
        # or Experiment.load / get_segment raises part-way through
        with h5py.File(output_file, 'r') as hf:

            mt = MEMDTransform()

            mt.num_samples = hf.attrs['num_samples']
            mt.num_noise_channels = hf.attrs['num_noise_channels']
            mt.num_bands = hf.attrs['num_bands']
            mt.resolution = hf.attrs['resolution']
            mt.max_iterations = hf.attrs['max_iterations']

            mt.index2electrode = hf.attrs['index2electrode']
            mt.sample_rate = hf.attrs['sample_rate']
            mt.rcg_names = hf.attrs['rcg_names']

            mt.start_time = hf.attrs['start_time']
            mt.end_time = hf.attrs['end_time']

            mt.X = None
            mt.Z = None
            mt.file_name = output_file

            # load up Experiment object
            stim_file_name = hf.attrs['stimulus_file']
            exp_file_name = hf.attrs['experiment_file']
            mt.experiment = Experiment.load(exp_file_name, stim_file_name)

            # get the right segment
            seg_name = hf.attrs['segment_name']
            mt.block_name = hf.attrs['block_name']
            mt.segment = mt.experiment.get_segment(mt.block_name, seg_name)

            # np.array copies the data into memory so it outlives the file handle
            if load_real:
                mt.X = np.array(hf['bands'])
            if load_complex:
                mt.Z = np.array(hf['bands_analytic'])

        return mt
Ejemplo n.º 4
0
    def load(clz, output_file, load_experiment=False):
        """Rebuild a ``SpikeEventTransform`` from an HDF5 file.

        Settings come from the file's root attributes. Every top-level group is
        treated as one segment, keyed by the group name, and must contain the
        datasets ``'envelope'``, ``'events'`` and ``'stim_envelope'`` plus an
        ``'epoch_table'`` subgroup. The subgroup is loaded as a dict of arrays,
        one per member.

        :param output_file: path of the HDF5 file to read.
        :param load_experiment: if True, also load the ``Experiment`` named by
            the stored experiment/stimulus file paths into ``se.experiment``.
        :return: the populated ``SpikeEventTransform``.
        """
        se = SpikeEventTransform()
        # the context manager closes the file even if a group/dataset is missing
        with h5py.File(output_file, 'r') as hf:
            se.experiment_file = hf.attrs['experiment_file']
            se.stim_file = hf.attrs['stim_file']
            se.bin_size = hf.attrs['bin_size']
            se.sort_code = hf.attrs['sort_code']
            se.rcg_names = hf.attrs['rcg_names']
            # assumes se.envelopes / events / spec_envelopes / epoch_tables are
            # dicts created by the constructor — TODO confirm
            for seg_name in list(hf.keys()):
                sgrp = hf[seg_name]
                se.envelopes[seg_name] = np.array(sgrp['envelope'])
                se.events[seg_name] = np.array(sgrp['events'])
                se.spec_envelopes[seg_name] = np.array(sgrp['stim_envelope'])
                egrp = sgrp['epoch_table']
                se.epoch_tables[seg_name] = {key: np.array(egrp[key])
                                             for key in list(egrp.keys())}

        # loaded after the HDF5 file is closed; only needs the stored paths
        if load_experiment:
            se.experiment = Experiment.load(se.experiment_file, se.stim_file)

        return se
Ejemplo n.º 5
0
def _read_manual_segmentation(seg_file):
    """Parse a manual segmentation file into ``{stim_id: (n_syllables, 2) array}``.

    Each non-blank line is ``stim_id,t0,t1,...`` with start/end times
    alternating, the same layout the ``segment`` method appends. Empty fields
    are ignored, so a stimulus saved with no times (``"id,"``) maps to an
    empty ``(0, 2)`` array.

    :raises ValueError: if a line holds an odd number of times.
    """
    man_segs = dict()
    with open(seg_file, 'r') as fd:
        for ln in fd:
            fields = [s.strip() for s in ln.split(',')]
            if not fields[0]:
                continue  # blank line
            stim_id = int(fields[0])
            stimes = [float(s) for s in fields[1:] if s]
            if len(stimes) % 2 != 0:
                raise ValueError("Uneven # of syllables for stim %d" % stim_id)
            # integer division: reshape needs ints, and `/` yields a float in Python 3
            man_segs[stim_id] = np.array(stimes).reshape([len(stimes) // 2, 2])
    return man_segs


def _plot_segmented_spectrogram(ax, spec_t, spec_freq, spec, segs, title):
    """Draw the spectrogram on ``ax`` with numbered start/end lines for each (start, end) row of ``segs``."""
    plot_spectrogram(spec_t,
                     spec_freq,
                     spec,
                     ax=ax,
                     colorbar=False,
                     fmin=300.,
                     fmax=8000.,
                     colormap='SpectroColorMap')
    for k, (stime, etime) in enumerate(segs):
        ax.axvline(stime, c='k', linewidth=2.0, alpha=0.8)
        ax.axvline(etime, c='k', linewidth=2.0, alpha=0.8)
        ax.text(stime, 7000., str(k + 1), fontsize=14)
        ax.text(etime, 7000., str(k + 1), fontsize=14)
    ax.set_title(title)
    ax.axis('tight')


def compare_stims(exp_file, stim_file, seg_file, bs_file):
    """Show the algorithmic and the manual syllable segmentation of every stimulus.

    For each stimulus id found in the experiment's epoch tables, a figure
    with two copies of its log spectrogram is drawn. The top panel is
    annotated with the segments stored in the ``BiosoundTransform`` file
    ``bs_file``, the bottom with the manual segments from ``seg_file``. Each
    figure is shown with ``plt.show()``.

    :param exp_file: experiment file passed to ``Experiment.load``.
    :param stim_file: stimulus file passed to ``Experiment.load``.
    :param seg_file: manual segmentation file (``stim_id,t0,t1,...`` per line).
    :param bs_file: file passed to ``BiosoundTransform.load``.
    :raises ValueError: if a line of ``seg_file`` holds an odd number of times.
    """
    exp = Experiment.load(exp_file, stim_file)
    spec_colormap()

    # collect the unique stimulus ids across all epoch tables
    all_stim_ids = list()
    for ekey in list(exp.epoch_table.keys()):
        etable = exp.epoch_table[ekey]
        all_stim_ids.extend(etable['id'].unique())
    stim_ids = np.unique(all_stim_ids)

    # read the manual segmentation data
    man_segs = _read_manual_segmentation(seg_file)
    # stimuli without a manual segmentation get an empty annotation instead of a KeyError
    no_segs = np.zeros([0, 2])

    # get the automated segmentation, sorted by start time
    algo_segs = dict()
    bst = BiosoundTransform.load(bs_file)
    for stim_id in stim_ids:
        i = bst.stim_df.stim_id == stim_id
        d = list(zip(bst.stim_df[i].start_time, bst.stim_df[i].end_time))
        d.sort(key=operator.itemgetter(0))
        algo_segs[stim_id] = np.array(d)

    spec_sr = 1000.
    for stim_id in stim_ids:

        # get the raw sound pressure waveform
        wave = exp.sound_manager.reconstruct(stim_id)
        wave_sr = wave.samplerate
        wave = np.array(wave).squeeze()

        # compute the log power spectrogram
        spec_t, spec_freq, spec, spec_rms = gaussian_stft(wave,
                                                          float(wave_sr),
                                                          0.007,
                                                          1. / spec_sr,
                                                          min_freq=300.,
                                                          max_freq=8000.)
        spec = np.abs(spec)**2
        log_transform(spec, dbnoise=70)

        plt.figure(figsize=(23, 12))
        _plot_segmented_spectrogram(plt.subplot(2, 1, 1), spec_t, spec_freq, spec,
                                    algo_segs[stim_id], 'Algorithm Segmentation')
        _plot_segmented_spectrogram(plt.subplot(2, 1, 2), spec_t, spec_freq, spec,
                                    man_segs.get(stim_id, no_segs), 'Manual Segmentation')
        plt.show()
Ejemplo n.º 6
0
    hf2.close()

    #set the data properties
    for prop_name in props_to_merge:
        hf[prop_name] = merged_arrays[prop_name]

    hf.close()


if __name__ == '__main__':

    exp_name = 'YelBlu6903F'
    exp_file = '/auto/tdrive/mschachter/data/%s/%s.h5' % (exp_name, exp_name)
    stim_file = '/auto/tdrive/mschachter/data/%s/stims.h5' % exp_name
    output_dir = '/auto/tdrive/mschachter/data/%s/transforms' % exp_name
    exp = Experiment.load(exp_file, stim_file)

    #start_time = 1280.0
    #end_time = 1300.0
    start_time = None
    end_time = None
    hemis = ['R']

    block_name = 'Site3'
    segment_name = 'Call3'

    segment = exp.get_segment(block_name, segment_name)
    seg_uname = segment_to_unique_name(segment)

    if start_time is not None and end_time is not None:
        ofname = 'MEMD_%s_%s_%0.3f_%0.3f.h5' % (seg_uname, ','.join(hemis),
Ejemplo n.º 7
0
            plt.savefig(ofile, facecolor=fig.get_facecolor(), edgecolor='none')
        else:
            plt.show()


if __name__ == '__main__':

    exp_name = 'GreBlu9508M'
    # derive every path from exp_name so switching experiments is a one-line
    # change (previously the directory was hard-coded separately)
    data_dir = os.path.join('/auto/tdrive/mschachter/data', exp_name)
    exp_file = os.path.join(data_dir, '%s.h5' % exp_name)
    stim_file = os.path.join(data_dir, 'stims.h5')
    output_dir = os.path.join(data_dir, 'transforms')

    block_name = 'Site4'
    segment_name = 'Call1'

    exp = Experiment.load(exp_file, stim_file, read_only_stims=True)
    segment = exp.get_segment(block_name, segment_name)
    seg_uname = segment_to_unique_name(segment)

    ofname = 'ICA_%s_L.h5' % seg_uname
    ofile = os.path.join(output_dir, ofname)

    # fit ICA on electrodes 1..16 and save the result
    icat = ICALFPTransform()
    icat.transform(exp, segment, electrodes=np.arange(16) + 1)
    icat.save(ofile)

    # To inspect a previously saved transform instead, use:
    # icat = ICALFPTransform.load(ofile)
    # icat.plot()
    # icat.plot_all_slices(output_dir='/auto/tdrive/mschachter/figures/ica/GreBlu9508M/L')
    # plt.show()