Example #1
def compare_lfp(load_loc, out_loc):
    ndata1 = NData()
    ndata2 = NData()
    grid = np.meshgrid(np.arange(32), np.arange(32), indexing='ij')
    stacked = np.stack(grid, 2)
    pairs = stacked.reshape(-1, 2)
    result_a = np.zeros(shape=pairs.shape[0], dtype=np.float32)

    for i, pair in enumerate(pairs):
        load_lfp(load_loc, pair[0], ndata1)
        load_lfp(load_loc, pair[1], ndata2)
        res = get_normalised_diff(ndata1.lfp.get_samples(),
                                  ndata2.lfp.get_samples())
        result_a[i] = res

    with open(out_loc, "w") as f:
        headers = [str(i) for i in range(1, 33)]
        out_str = ",".join(headers)
        f.write(out_str)
        out_str = ""
        for i, (pair, val) in enumerate(zip(pairs, result_a)):
            if i % 32 == 0:
                f.write(out_str + "\n")
                out_str = ""

            out_str += "{:.2f},".format(val)
            # f.write("({}, {}): {:.2f}\n".format(pair[0], pair[1], val))
        f.write(out_str + "\n")

    return result_a
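The meshgrid/stack/reshape idiom above is the core trick for enumerating every ordered channel pair. A minimal standalone sketch of the same idiom on a 4x4 grid:

import numpy as np

# Enumerate all ordered (row, col) index pairs for a 4-channel setup,
# exactly as compare_lfp does for its 32 channels.
n = 4
grid = np.meshgrid(np.arange(n), np.arange(n), indexing='ij')
pairs = np.stack(grid, 2).reshape(-1, 2)

print(pairs[:5])   # [[0 0], [0 1], [0 2], [0 3], [1 0]]
print(len(pairs))  # 16, i.e. n * n ordered pairs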
Example #2
    def list_all_units(self):
        """Print all the units in the container."""
        if self._load_on_fly:
            for key, vals in self.get_file_dict().items():
                if key == "Spike":
                    for descriptor in vals:
                        result = NData()
                        self._load(key, descriptor, ndata=result)
                        print("units are {}".format(result.get_unit_list()))
        else:
            for data in self._container:
                print("units are {}".format(data.get_unit_list()))
Example #3
    def get_data(self, index=None):
        """
        Return the NData objects in this collection, or a specific object.

        Do not call this with no index if loading data on the fly.

        Parameters
        ----------
        index : int
            Optional index to get the data at.
            Defaults to None, in which case all data is returned.

        Returns
        -------
        NData or list of NData objects

        """
        if self._load_on_fly:
            if index is None:
                logging.error("Can't load all data when loading on the fly")
                return None
            result = NData()
            for key, vals in self.get_file_dict().items():
                descriptor = vals[index]
                self._load(key, descriptor, ndata=result)
            return result
        if index is None:
            return self._container
        if index >= self.get_num_data():
            logging.error("Input index to get_data out of range")
            return
        return self._container[index]
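A short usage sketch, assuming a populated container (for example neurochat's NDataContainer; construction and file setup are not shown in this snippet):

# Hypothetical setup; files would be added to the container first.
collection = NDataContainer(load_on_fly=False)

first = collection.get_data(0)    # a single NData object
all_data = collection.get_data()  # the full list; invalid when loading on the fly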
Example #4
    def _load_all_data(self):
        """Intended private function which loads all the data."""
        if self._load_on_fly:
            logging.error(
                "Don't load all the data in container if loading on the fly")
            return
        for key, vals in self.get_file_dict().items():
            for idx, _ in enumerate(vals):
                if idx >= self.get_num_data():
                    self.add_data(NData())

            for idx, descriptor in enumerate(vals):
                self._load(key, descriptor, idx=idx)
Example #5
    def get_data_at(self, data_index, unit_index):
        """
        Return an NData object from the given indices.

        Parameters
        ----------
        data_index : int
            The index in the container to return data at.
        unit_index : int
            The unit number to set on the returned data.

        Returns
        -------
        NData
            The NData object at data_index with unit number unit_index.

        """
        if self._load_on_fly:
            try:
                if (data_index == self._last_data_pt[0]
                        and (self._last_data_pt[1] is not None)):
                    result = self._last_data_pt[1]
                else:
                    result = NData()
                    for key, vals in self.get_file_dict().items():
                        if key == "STM":
                            continue
                        descriptor = vals[data_index]
                        self._load(key,
                                   descriptor,
                                   idx=data_index,
                                   ndata=result)
                    self._last_data_pt = (data_index, result)
            except Exception as e:
                log_exception(e, "During loading data")
                return None
        else:
            result = self.get_data(data_index)
        if len(self.get_units()) > 0:
            result.set_unit_no(self.get_units(data_index)[unit_index])
        return result
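The _last_data_pt tuple is a single-entry cache: the disk load is skipped when the same data_index is requested twice in a row. A minimal standalone sketch of that pattern, with a hypothetical load function passed in:

class CachedLoader:
    """Single-entry cache around an expensive per-index load."""

    def __init__(self, load_fn):
        self._load_fn = load_fn    # e.g. a load from disk
        self._last = (None, None)  # (index, result) of the last load

    def get(self, index):
        if index == self._last[0] and self._last[1] is not None:
            return self._last[1]       # cache hit: reuse the last result
        result = self._load_fn(index)  # cache miss: load and remember
        self._last = (index, result)
        return result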
Example #6
    def set_units(self, units='all'):
        """
        Set the list of units for the collection.

        Parameters
        ----------
        units : list or str
            If a list, it gives the units to use for each data object
            stored in the collection; "all" is accepted as an element.
            Otherwise, the string "all" is expected, which selects all
            available units picked up from the clustering files.
        """
        self._units = []
        if self.get_file_dict() == {}:
            raise ValueError("Can't set units for empty collection")
        if units == 'all':
            if self._load_on_fly:
                vals = self.get_file_dict()["Spike"]
                for descriptor in vals:
                    result = NData()
                    self._load("Spike", descriptor, ndata=result)
                    self._units.append(result.get_unit_list())
            else:
                for data in self.get_data():
                    self._units.append(data.get_unit_list())

        elif isinstance(units, list):
            for idx, unit in enumerate(units):
                if unit == 'all':
                    if self._load_on_fly:
                        vals = self.get_file_dict()["Spike"]
                        descriptor = vals[idx]
                        result = NData()
                        self._load("Spike", descriptor, ndata=result)
                        all_units = result.get_unit_list()
                    else:
                        all_units = self.get_data(idx).get_unit_list()
                    self._units.append(all_units)
                elif isinstance(unit, int):
                    self._units.append([unit])
                elif isinstance(unit, list):
                    self._units.append(unit)
                else:
                    logging.error(
                        "Unrecognised type {} passed to set units".format(
                            type(unit)))

        else:
            logging.error("Unrecognised type {} passed to set units".format(
                type(units)))
        self._unit_count = self._count_num_units()
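A usage sketch, assuming a populated collection as in the earlier examples:

# Every unit found in the clustering files, for every data object.
collection.set_units('all')

# Per-object control: all units for the first object, unit 3 for the
# second, and units 1 and 2 for the third.
collection.set_units(['all', 3, [1, 2]])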
Example #7
def main(url, file_name, verbose=False):
    """
    This demonstrates example hdf5 usage.

    url should be the download url of a hdf5 file.
    file_name is a local disk path to store that file in.
    if verbose is true, information about the hdf5 file is printed.

    """
    # Fetch a file from OSF if not available on disk
    if not os.path.exists(file_name):
        print("Downloading file from {} to {}".format(url, file_name))
        urllib.request.urlretrieve(url, file_name)
    else:
        print("Using {}".format(file_name))

    if verbose:
        from skm_pyutils.py_print import print_h5
        print_h5(file_name)

    # Set up the h5 paths
    spike_path = "/processing/Shank/7"
    pos_path = "/processing/Behavioural/Position"
    lfp_path = "/processing/Neural Continuous/LFP/eeg"

    # HDF5 loading here requires filename + "+" + path_in_hdf5;
    # this helper builds that combined string.
    def to_hdf_path(x):
        return file_name + "+" + x

    # Load in that data
    ndata = NData()
    ndata.set_data_format("NWB")
    ndata.set_spatial_file(to_hdf_path(pos_path))
    ndata.set_spike_file(to_hdf_path(spike_path))
    ndata.set_lfp_file(to_hdf_path(lfp_path))
    ndata.load()

    # Choose the unit number from those available
    print("Units are:", ndata.get_unit_list())
    unit_no = int(input("Unit to use:\n").strip())
    ndata.set_unit_no(unit_no)
    print("Loaded:", ndata)

    # Perform analysis
    ndata.place()
    ndata.wave_property()
    # print(ndata.get_results()["Spatial Skaggs"])
    # print(ndata.get_results()["Mean Spiking Freq"])
    print(ndata.get_results(spaces_to_underscores=True))
    print(ndata.get_results(spaces_to_underscores=False))
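A hedged invocation sketch; the URL and file name below are placeholders, not real data:

if __name__ == "__main__":
    # Substitute a real OSF download URL and a local path of your choosing.
    main("https://osf.io/abcde/download", "example_data.h5", verbose=True)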
Example #8
def load_axona_data():
    data_dir = r'C:\Users\smartin5\Recordings\recording_example'
    spike_file = os.path.join(data_dir, "010416b-LS3-50Hz10V5ms.2")
    pos_file = os.path.join(data_dir, "010416b-LS3-50Hz10V5ms_2.txt")
    lfp_file = os.path.join(data_dir, "010416b-LS3-50Hz10V5ms.eeg")
    unit_no = 7
    ndata = NData()
    ndata.set_spike_file(spike_file)
    ndata.set_spatial_file(pos_file)
    ndata.set_lfp_file(lfp_file)
    ndata.load()
    ndata.set_unit_no(unit_no)

    return ndata
Example #9
def load_h5_data():
    data_dir = r'C:\Users\smartin5\Recordings\NC_eg'
    main_file = "040513_1.hdf5"
    spike_file = "/processing/Shank/6"
    pos_file = "/processing/Behavioural/Position"
    lfp_file = "/processing/Neural Continuous/LFP/eeg"
    unit_no = 3

    def m_file(x):
        return os.path.join(data_dir, main_file + "+" + x)
    ndata = NData()
    ndata.set_data_format(data_format='NWB')
    ndata.set_spatial_file(m_file(pos_file))
    ndata.set_spike_file(m_file(spike_file))
    ndata.set_lfp_file(m_file(lfp_file))
    ndata.load()
    ndata.set_unit_no(unit_no)

    return ndata
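Both H5 loaders use the same addressing convention as the NWB example above: the on-disk filename and the internal HDF5 path are joined with a "+". For instance (the path separator depends on the OS):

print(m_file("/processing/Shank/6"))
# C:\Users\smartin5\Recordings\NC_eg\040513_1.hdf5+/processing/Shank/6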
Example #10
def compare_lfp(fname, out_base_dir=None, ch=16):
    """
    Compare all pairs of LFP channels in a session and save the results.

    Parameters
    ----------
    fname : str
        Full path name without extension.
    out_base_dir : str, optional
        Path for the desired output location. Defaults to None, in which
        case output is saved to a folder named !LFP in the base directory.
    ch : int
        Number of LFP channels in the session.

    """
    if out_base_dir is None:
        out_base_dir = os.path.join(os.path.dirname(fname), r"!LFP")
        make_dir_if_not_exists(out_base_dir)
    load_loc = fname + ".eeg"
    out_name = os.path.basename(fname) + "_SI.csv"
    out_loc = os.path.join(out_base_dir, out_name)

    ndata1 = NData()
    ndata2 = NData()
    grid = np.meshgrid(np.arange(ch), np.arange(ch), indexing='ij')
    stacked = np.stack(grid, 2)
    pairs = stacked.reshape(-1, 2)
    result_a = np.zeros(shape=pairs.shape[0], dtype=np.float32)

    for i, pair in enumerate(pairs):
        load_lfp(load_loc, pair[0], ndata1)
        load_lfp(load_loc, pair[1], ndata2)
        res = get_normalised_diff(ndata1.lfp.get_samples(),
                                  ndata2.lfp.get_samples())
        result_a[i] = res

    with open(out_loc, "w") as f:
        headers = [str(i) for i in range(1, ch + 1)]
        out_str = ",".join(headers)
        f.write(out_str)
        out_str = ""
        for i, (pair, val) in enumerate(zip(pairs, result_a)):
            if i % ch == 0:
                f.write(out_str + "\n")
                out_str = ""

            out_str += "{:.2f},".format(val)
            # f.write("({}, {}): {:.2f}\n".format(pair[0], pair[1], val))
        f.write(out_str + "\n")

    reshaped = result_a.reshape(ch, ch)
    sns.heatmap(reshaped)
    plt.xticks(np.arange(0.5, ch + 0.5),
               labels=np.arange(1, ch + 1),
               fontsize=8)
    plt.xlabel('LFP Channels')
    plt.yticks(np.arange(0.5, ch + 0.5),
               labels=np.arange(1, ch + 1),
               fontsize=8)
    plt.ylabel('LFP Channels')
    plt.title('Raw LFP Similarity Index')
    fig_path = os.path.join(out_base_dir,
                            os.path.basename(fname) + "_LFP_SI.png")
    print("Saving figure to {}".format(fig_path))
    plt.savefig(fig_path, dpi=200, bbox_inches='tight', pad_inches=0.1)
    return result_a
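The snippet does not include get_normalised_diff. One plausible definition, consistent with a similarity index that is 0 for identical traces, is a squared difference normalised by the combined signal power; this is an assumption, not the library's actual code:

import numpy as np

def normalised_diff_sketch(s1, s2):
    """Hypothetical stand-in for get_normalised_diff."""
    s1 = np.asarray(s1, dtype=float)
    s2 = np.asarray(s2, dtype=float)
    denom = np.sum(s1 ** 2) + np.sum(s2 ** 2)
    if denom == 0:
        return 0.0  # both traces are all-zero
    # 0 for identical signals; larger values indicate greater difference.
    return np.sum((s1 - s2) ** 2) / denom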
Example #11
def load_data(data_dir, spike_name, pos_name, lfp_name):
    spike_file = os.path.join(data_dir, spike_name)
    pos_file = os.path.join(data_dir, pos_name)
    lfp_file = os.path.join(data_dir, lfp_name)
    unit_no = 6
    ndata = NData()
    ndata.set_spike_file(spike_file)
    ndata.set_spatial_file(pos_file)
    ndata.set_lfp_file(lfp_file)
    ndata.load()
    ndata.set_unit_no(unit_no)
    return ndata
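A usage sketch, reusing the Axona file names from Example #8 (the unit number is fixed at 6 inside load_data):

ndata = load_data(
    r"C:\Users\smartin5\Recordings\recording_example",
    "010416b-LS3-50Hz10V5ms.2",      # spike (tetrode) file
    "010416b-LS3-50Hz10V5ms_2.txt",  # position file
    "010416b-LS3-50Hz10V5ms.eeg",    # LFP file
)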