def import_from_json(self, str_json: str) -> bool:
     """Replace the backing data with the ModelData parsed from *str_json*.

     Returns False and leaves the model untouched when ModelData.from_json
     yields None; otherwise swaps the data inside a model reset (so attached
     views reload) and returns True.
     """
     parsed = ModelData.from_json(str_json)
     if parsed is None:
         return False
     self.beginResetModel()
     self._model_data = parsed
     self.endResetModel()
     return True
    def __init__(self):
        """Build the window UI, attach an empty durations model and wire its signals.

        After the model is installed on ``tabv_intervals``, every model change
        notification is routed to the matching ``on_model_*`` handler, and the
        source/output views are refreshed once.
        """
        super(MainWindow, self).__init__()
        super(MainWindow, self).setupUi(self)

        self._model = DurationsListModel(ModelData(), self.tabv_intervals)
        self._json_path = None
        self.tabv_intervals.setModel(self._model)

        self.lbl_file_drop.changeFile.connect(self.on_lbl_file_drop_change_file)

        # (signal, slot) pairs, connected in this order.
        model = self._model
        wiring = (
            (model.dataChanged, self.on_model_data_changed),
            (model.rowsInserted, self.on_model_rows_inserted),
            (model.rowsMoved, self.on_model_rows_moved),
            (model.rowsRemoved, self.on_model_rows_removed),
            (model.modelReset, self.on_model_model_reset),
        )
        for signal, slot in wiring:
            signal.connect(slot)

        self.update_source()
        self.update_output()
Example No. 3
0
 def _run_sim(self, model_path: str, sim_type: str, network_state: str, debug: bool, drop_list=None):
     """
     Load a network model, build a gradient simulation around it and run it
     :param model_path: Full path to the network model definitions
     :param sim_type: "r" selects the radial (circular) simulation, anything else the linear one
     :param network_state: "naive" loads the first checkpoint, every other value the last one;
         "bfevolve" / "partevolve" additionally install evolved behavior weights
         (last / 8th entry of evolve/generation_weights.npy, averaged over axis 0);
         "ideal" runs sim.run_ideal instead of sim.run_simulation
     :param debug: If true return debug information as well (not used for "ideal")
     :param drop_list: Optional det_drop dictionary of lists of units to keep or drop
     :return:
         [0]: n_steps x 3 matrix of position information at each step
         [1]: Debug dict if debug is true or None otherwise
     """
     mdata = ModelData(model_path)
     checkpoint = mdata.FirstCheckpoint if network_state == "naive" else mdata.LastCheckpoint
     network = self.mo.network_model()
     network.load(mdata.ModelDefinition, checkpoint)

     if sim_type == "r":
         make_sim, sim_params = self.mo.rad_sim, GlobalDefs.circle_sim_params
     else:
         make_sim, sim_params = self.mo.lin_sim, GlobalDefs.lin_sim_params
     sim = make_sim(network, self.std, **sim_params)
     sim.remove = drop_list

     # Which stored generation supplies the behavior weights:
     # -1 = final generation, 7 = within the initial decline (partial evolution).
     generation = {"bfevolve": -1, "partevolve": 7}.get(network_state)
     if generation is not None:
         all_weights = np.load(model_path + '/evolve/generation_weights.npy')
         sim.bf_weights = np.mean(all_weights[generation, :, :], 0)

     if network_state == "ideal":
         return sim.run_ideal(GlobalDefs.n_steps)
     return sim.run_simulation(GlobalDefs.n_steps, debug)
def get_cell_responses(temp, standards):
    """
    Ask the user for a model directory, load its last checkpoint and return the
    responses of all units to a temperature stimulus.

    :param temp: Temperature stimulus; a lead-in of ``input_dims[2] - 1`` samples
        at the mean of its first 10 samples is prepended before evaluation
    :param standards: Forwarded unchanged to ``unit_stimulus_responses``
        (presumably the input normalization standards - confirm)
    :return: n-timepoints x m-neurons matrix of responses; temperature-branch
        units ('t') when present, otherwise mixed-branch units ('m')
    :raises ValueError: if the directory dialog is cancelled
    """
    print("Select model directory")
    root = tk.Tk()
    try:
        root.update()
        root.withdraw()
        model_dir = filedialog.askdirectory(
            title="Select directory with model checkpoints",
            initialdir="./model_data/")
        root.update()
    finally:
        # The hidden root window exists only to host the dialog; destroy it so
        # it does not linger for the rest of the process.
        root.destroy()
    # askdirectory returns an empty value when the dialog is cancelled.
    if not model_dir:
        raise ValueError("No model directory selected")
    mdata = ModelData(model_dir)
    # create our model and load from last checkpoint
    gpn = ZfGpNetworkModel()
    gpn.load(mdata.ModelDefinition, mdata.LastCheckpoint)
    # prepend lead-in to stimulus
    lead_in = np.full(gpn.input_dims[2] - 1, np.mean(temp[:10]))
    temp = np.r_[lead_in, temp]
    activities = gpn.unit_stimulus_responses(temp, None, None, standards)
    branch = 't' if 't' in activities else 'm'
    return np.hstack(activities[branch])
Example No. 5
0
 def _compute_cell_responses(self, model_dir, temp, network_id, drop_list=None):
     """
     Evaluate the last checkpoint of a network on a temperature stimulus and label every unit
     :param model_dir: The directory of the network model
     :param temp: The temperature input to test on the network
     :param network_id: Numerical id of the network to later relate units back to a network
     :param drop_list: Optional det_drop dictionary of lists of units to keep or drop
     :return:
         [0]: n-timepoints x m-neurons matrix of responses (temperature branch 't' if present, else mixed branch 'm')
         [1]: 3 x m-neurons matrix with network_id in row 0, layer index in row 1, and unit index in row 2
     """
     model_data = ModelData(model_dir)
     network = self.mo.network_model()
     network.load(model_data.ModelDefinition, model_data.LastCheckpoint)
     # Pad the stimulus front with input_dims[2] - 1 samples at the mean of its first 10 samples.
     history_len = network.input_dims[2] - 1
     padded = np.r_[np.full(history_len, np.mean(temp[:10])), temp]
     responses = network.unit_stimulus_responses(padded, None, None, self.std, det_drop=drop_list)

     # Use the temperature branch when the network reports one, otherwise the mixed branch;
     # every layer of the chosen branch is assumed to have the same width.
     if 't' in responses:
         branch = 't'
         units_per_layer = network.n_units[0]
         n_layers = network.n_layers_branch
     else:
         branch = 'm'
         units_per_layer = network.n_units[1]
         n_layers = network.n_layers_mixed
     activities = np.hstack(responses[branch])

     id_mat = np.zeros((3, activities.shape[1]), dtype=np.int32)
     id_mat[0, :] = network_id
     for layer in range(n_layers):
         cols = slice(layer * units_per_layer, (layer + 1) * units_per_layer)
         id_mat[1, cols] = layer
         id_mat[2, cols] = np.arange(units_per_layer, dtype=np.int32)
     return activities, id_mat
Example No. 6
0
                       tsin.size * GlobalDefs.frame_rate // 20)
 # Resample the stimulus onto the xinterp grid (xinterp's definition begins above this fragment).
 temperature = np.interp(xinterp, x, tsin)
 dfile.close()
 # Collect the unit-id matrix of every network; the responses (cell_res) are not used here.
 all_ids = []
 for i, p in enumerate(paths_512):
     cell_res, ids = ana.temperature_activity(mpath(p), temperature, i)
     all_ids.append(ids)
 all_ids = np.hstack(all_ids)
 # Presumably one cluster id per unit, aligned with all_ids - confirm against the file producer.
 clfile = h5py.File("ce_cluster_info.hdf5", "r")
 clust_ids = np.array(clfile["clust_ids"])
 clfile.close()
 train_ix = [0, 1]
 for i, p in enumerate(paths_512):
     # NOTE(review): train_ix is shuffled in place on every iteration; its meaning is defined by retrain().
     np.random.shuffle(train_ix)
     model_path = mpath(p)
     mdata = ModelData(model_path)
     # t-branch retrain
     cel_folder = model_path + "/cel_tbranch_retrain"
     model = None
     # Drop list for network i, built from the cluster ids; `a` and `ce_like` are defined outside this fragment.
     dlist = a.create_det_drop_list(i, clust_ids, all_ids, ce_like)
     # An existing output folder is treated as "already retrained" and skipped.
     if os.path.exists(cel_folder):
         print(
             "Temperature branch retrain folder on model {0} already exists. Skipping."
             .format(p))
     else:
         os.mkdir(cel_folder)
         model = CeGpNetworkModel()
         model.load(mdata.ModelDefinition, mdata.LastCheckpoint)
         # The predicate presumably selects which variables to retrain by name ("_t_") - confirm in retrain().
         retrain(model, cel_folder, dlist, train_ix, lambda n: "_t_" in n)
     # m branch retrain
     # NOTE(review): the code that uses this folder is not part of this fragment.
     cel_folder = model_path + "/cel_nontbranch_retrain"
 # load and interpolate temperature stimulus
 dfile = h5py.File("stimFile.hdf5", 'r')
 tsin = np.array(dfile['sine_L_H_temp'])
 x = np.arange(tsin.size)  # stored at 20 Hz !
 # Resample from 20 Hz to GlobalDefs.frame_rate samples per second.
 xinterp = np.linspace(0, tsin.size,
                       tsin.size * GlobalDefs.frame_rate // 20)
 temp = np.interp(xinterp, x, tsin)
 dfile.close()
 print("Select model directory")
 # Hidden Tk root window that hosts the directory dialog.
 # NOTE(review): in the visible code the root is never destroyed, and a cancelled
 # dialog would pass an empty path to ModelData.
 root = tk.Tk()
 root.update()
 root.withdraw()
 model_dir = filedialog.askdirectory(
     title="Select directory with model checkpoints",
     initialdir="./model_data/")
 mdata = ModelData(model_dir)
 root.update()
 # create our model and load from last checkpoint
 gpn = ZfGpNetworkModel()
 gpn.load(mdata.ModelDefinition, mdata.LastCheckpoint)
 # prepend lead-in to stimulus: input_dims[2] - 1 samples at the mean of the first 10 samples
 lead_in = np.full(gpn.input_dims[2] - 1, np.mean(temp[:10]))
 temp = np.r_[lead_in, temp]
 # run a short simulation to create some sample trajectories for speed and angle inputs
 sim = CircGradientTrainer(100, 22, 37)
 sim.p_move = 0.1 / GlobalDefs.frame_rate  # use reduced movement rate to aid visualization
 # One extra step so that, assuming one row per step, the np.diff calls below
 # yield temp.size samples.
 pos = sim.run_simulation(temp.size + 1)
 # Per-step displacement magnitude from columns 0-1 (presumably x, y).
 spd = np.sqrt(np.sum(np.diff(pos[:, :2], axis=0)**2, 1))
 # Per-step change of column 2 (presumably heading angle).
 da = np.diff(pos[:, 2])
 # `std` is defined outside this fragment.
 activities = gpn.unit_stimulus_responses(temp, spd, da, std)
 # make actual movie at five hertz, simply by skipping and also only create first repeat
class DurationsListModel(QAbstractTableModel):
    """Qt table model exposing the intervals of a ``ModelData`` as two columns.

    Column 0 is an interval's begin and column 1 its end; both are displayed
    and edited as ``HH:MM:SS`` text.
    """

    # H:M:S with any number of hour digits and 0-59 minutes/seconds.
    __regexp_moment = re.compile(r'([0-9]+):([0-5]?[0-9]):([0-5]?[0-9])')

    def __init__(self, model_data: ModelData, parent: QObject = None):
        super(DurationsListModel, self).__init__(parent)
        self._model_data = model_data

    def rowCount(self, parent: QModelIndex = None, *args, **kwargs) -> int:
        """Return the number of intervals.

        This is a flat table, so a valid parent has no children (Qt contract).
        """
        if parent is not None and parent.isValid():
            return 0
        return self._model_data.intervals_size()

    def columnCount(self, parent: QModelIndex = None, *args, **kwargs) -> int:
        """Return 2 (begin, end); 0 for a valid parent."""
        if parent is not None and parent.isValid():
            return 0
        return 2

    def headerData(self,
                   section: int,
                   orientation: int,
                   role: int = None) -> QVariant:
        """Label the horizontal header "Begin"/"End"; defer vertical headers to Qt."""
        if role != Qt.DisplayRole:
            return QVariant()
        if orientation == Qt.Horizontal:
            if section == 0:
                return QVariant("Begin")
            elif section == 1:
                return QVariant("End")
            else:
                return QVariant()
        return super(DurationsListModel, self) \
            .headerData(section, orientation, role)

    def data(self, index: QModelIndex, role=None) -> QVariant:
        """Return the begin/end moment of the row as zero-padded ``HH:MM:SS`` for DisplayRole."""
        if not index.isValid():
            return QVariant()
        if 0 <= index.row() < self._model_data.intervals_size():
            if role == Qt.DisplayRole:
                row, col = index.row(), index.column()
                if col == 0 or col == 1:
                    t = self._model_data.get_interval_unwrap(row, col)
                    return QVariant('%02d:%02d:%02d' % (t[0], t[1], t[2]))
        return QVariant()

    def setData(self,
                index: QModelIndex,
                value: QVariant,
                role: int = None) -> bool:
        """Parse *value* as ``H:M:S`` and store it as the row's begin (col 0) or end (col 1).

        Returns True and emits ``dataChanged`` only when the value was stored.
        Returns False when the index is invalid, the role is not EditRole, the
        text is not exactly ``H:M:S``, ``Moment.validate`` rejects it, or the
        new begin would be after the end (or the new end before the begin).
        """
        if not index.isValid() or role != Qt.EditRole:
            return False
        row, col = index.row(), index.column()
        if not 0 <= row < self._model_data.intervals_size() or col not in (0, 1):
            return False
        # fullmatch (not match) so trailing text such as "12:30:456" is
        # rejected instead of being silently truncated to 12:30:45.
        match_res = self.__regexp_moment.fullmatch(str(value).strip())
        if not match_res:
            return False
        hour, mins, secs = [int(i) for i in match_res.groups()]
        if not Moment.validate(hour, mins, secs):
            return False
        interval = self._model_data.get_interval_unwrap(row)
        if col == 0:
            if [hour, mins, secs] > interval[1]:
                return False
            self._model_data.set_interval(row, begin=Moment(hour, mins, secs))
        else:
            if [hour, mins, secs] < interval[0]:
                return False
            self._model_data.set_interval(row, end=Moment(hour, mins, secs))
        self.dataChanged.emit(index, index, [role])
        return True

    def flags(self, index: QModelIndex) -> Qt.ItemFlags:
        """Make every valid cell editable on top of the default flags."""
        if not index.isValid():
            return Qt.ItemIsEnabled
        flags = super(DurationsListModel, self).flags(index)
        return flags | Qt.ItemIsEditable

    def insertRow(self,
                  row: int,
                  parent: QModelIndex = None,
                  *args,
                  **kwargs) -> bool:
        return self.insertRows(row, 1, parent, *args, **kwargs)

    def insertRows(self,
                   row: int,
                   count: int,
                   parent: QModelIndex = None,
                   *args,
                   **kwargs) -> bool:
        """Insert *count* 00:00:00-00:00:00 intervals before *row*.

        Returns False without touching the model when *parent* is valid,
        *count* is not positive, or *row* is outside ``0..rowCount()``.
        """
        if parent is not None and parent.isValid():
            return False
        if count <= 0 or not 0 <= row <= self.rowCount():
            return False
        self.beginInsertRows(QModelIndex(), row, row + count - 1)
        for i in range(count):
            self._model_data.insert_interval(row + i, Moment(0, 0, 0),
                                             Moment(0, 0, 0))
        self.endInsertRows()
        return True

    def removeRow(self,
                  row: int,
                  parent: QModelIndex = None,
                  *args,
                  **kwargs) -> bool:
        return self.removeRows(row, 1, parent, *args, **kwargs)

    def removeRows(self,
                   row: int,
                   count: int,
                   parent: QModelIndex = None,
                   *args,
                   **kwargs) -> bool:
        """Remove *count* intervals starting at *row*.

        The range is validated *before* beginRemoveRows, so an invalid request
        returns False without emitting removal signals for rows that still exist.
        """
        if parent is not None and parent.isValid():
            return False
        intervals_size = self._model_data.intervals_size()
        if count <= 0 or row < 0 or row + count > intervals_size:
            return False
        self.beginRemoveRows(QModelIndex(), row, row + count - 1)
        # Delete from the back so earlier indices stay valid.
        for i in reversed(range(row, row + count)):
            self._model_data.del_interval(i)
        self.endRemoveRows()
        return True

    def add_interval(self, begin: Moment, end: Moment):
        """Append an interval; silently ignored when begin > end."""
        if begin <= end:
            index = self.rowCount()
            self.beginInsertRows(QModelIndex(), index, index)
            self._model_data.add_interval(begin, end)
            self.endInsertRows()

    def move_interval(self, row: int, offset: int):
        """Move the interval at *row* by *offset* rows; ignored if 0 or out of range."""
        row_src, row_dst = row, row + offset
        row_cnt = self.rowCount()
        if offset != 0 \
                and 0 <= row_src < row_cnt \
                and 0 <= row_dst < row_cnt:
            # Qt's destination is the row *before which* the item lands,
            # counted before the move, hence +1 when moving down.
            self.beginMoveRows(QModelIndex(), row_src, row_src, QModelIndex(),
                               row_dst if offset < 0 else row_dst + 1)
            self._model_data.move_interval(row, offset)
            self.endMoveRows()

    def clear_intervals(self):
        """Remove all intervals, then emit a model reset."""
        row_cnt = self.rowCount()
        if row_cnt > 0:
            self.beginRemoveRows(QModelIndex(), 0, row_cnt - 1)
            self._model_data.clear_intervals()
            self.endRemoveRows()
        else:
            # beginRemoveRows(0, -1) is an invalid range; nothing to announce.
            self._model_data.clear_intervals()
        # Kept so modelReset listeners are notified even with no rows to remove.
        self.beginResetModel()
        self.endResetModel()

    def iter_intervals(self):
        return self._model_data.intervals_iter()

    @property
    def src_filename(self) -> str:
        return self._model_data.src_filename

    @src_filename.setter
    def src_filename(self, value: str):
        self._model_data.src_filename = value

    @property
    def src_path_dir(self) -> str:
        return self._model_data.src_path_dir

    @src_path_dir.setter
    def src_path_dir(self, value: str):
        self._model_data.src_path_dir = value

    @property
    def dst_path_dir(self) -> str:
        return self._model_data.dst_path_dir

    @dst_path_dir.setter
    def dst_path_dir(self, value: str):
        self._model_data.dst_path_dir = value

    # @property
    # def src_duration(self):
    #     return self._model_data.src_duration
    #
    # @src_duration.setter
    # def src_duration(self, value: int):
    #     self._model_data.src_duration = value
    #
    # @property
    # def src_size(self):
    #     return self._model_data.src_size
    #
    # @src_size.setter
    # def src_size(self, value: int):
    #     self._model_data.src_size = value

    def reset_data(self):
        """Drop all intervals and replace the backing data with a fresh ModelData."""
        self.clear_intervals()
        # Swapping the backing object also changes src_filename / src_path_dir /
        # dst_path_dir; wrap it in a reset so listeners see the new values
        # rather than the ones read during clear_intervals' reset.
        self.beginResetModel()
        self._model_data = ModelData()
        self.endResetModel()

    def import_from_json(self, str_json: str) -> bool:
        """Replace the data with ModelData parsed from JSON; False if parsing yields None."""
        model_data = ModelData.from_json(str_json)
        if model_data is not None:
            self.beginResetModel()
            self._model_data = model_data
            self.endResetModel()
            return True
        else:
            return False

    def export_to_json(self, *args, **kwargs) -> str:
        """Serialize the backing ModelData; arguments are forwarded to its to_json."""
        return self._model_data.to_json(*args, **kwargs)
 def reset_data(self):
     """Drop all intervals and replace the backing data with a fresh ModelData."""
     self.clear_intervals()
     # The swap also changes every other field held by the backing object;
     # wrap it in a model reset so views and listeners re-read the new values.
     self.beginResetModel()
     self._model_data = ModelData()
     self.endResetModel()
Example No. 10
0
    sns.despine()
    fig.tight_layout()

    # plot cluster sizes
    # (units with cluster id -1 are excluded - presumably unassigned units)
    fig, ax = pl.subplots()
    sns.countplot(clust_ids[clust_ids > -1], ax=ax)
    ax.set_ylabel("Cluster size")
    ax.set_xlabel("Cluster number")
    sns.despine(fig, ax)

    # plot white noise analysis of networks
    # behav_kernels maps each behavior name in k_names to a list with one kernel per network.
    behav_kernels = {}
    k_names = ["stay", "straight", "left", "right"]
    for p in paths_512:
        m_path = mpath(p)
        mdata_wn = ModelData(m_path)
        gpn_wn = ZfGpNetworkModel()
        gpn_wn.load(mdata_wn.ModelDefinition, mdata_wn.LastCheckpoint)
        wna = WhiteNoiseSimulation(std, gpn_wn, stim_std=2)
        wna.switch_mean = 5
        wna.switch_std = 1
        # Behavior weights: last entry along axis 0 of generation_weights.npy
        # (presumably the final evolved generation), averaged over its first axis.
        ev_path = m_path + '/evolve/generation_weights.npy'
        weights = np.load(ev_path)
        w = np.mean(weights[-1, :, :], 0)
        wna.bf_weights = w
        # NOTE(review): the argument is presumably the number of simulation steps - confirm.
        kernels = wna.compute_behavior_kernels(10000000)
        # Assumes kernels are returned in the same order as k_names.
        for i, n in enumerate(k_names):
            if n in behav_kernels:
                behav_kernels[n].append(kernels[i])
            else:
                behav_kernels[n] = [kernels[i]]