예제 #1
0
파일: analysis_tab.py 프로젝트: LBHB/NEMS
    def run_custom_group(self):
        """Run a user-specified analysis over the current group selection.

        The analysis is named by the text in ``self.lineCustomGroup`` and is
        resolved with ``lookup_fn_at``. It is called with:

        - the selected model name(s) as its first positional argument,
        - ``batch`` converted to ``int``,
        - the selected cell id(s) as ``goodcells``.

        Any ``Exception`` raised while resolving or running the analysis is
        reported on the console (message plus traceback) instead of being
        propagated to the caller.
        """
        import traceback

        analysis_name = self.lineCustomGroup.text()
        print('Custom group analysis: ', analysis_name)
        batch, selectedCellid, selectedModelname = self.get_selected()
        try:
            f = lookup_fn_at(analysis_name)
            f(selectedModelname, batch=int(batch), goodcells=selectedCellid)
        except Exception:
            # A bare ``except:`` used to hide the real error here (and also
            # caught KeyboardInterrupt/SystemExit). Keep running, but show
            # why the analysis failed.
            print('Unknown/incompatible analysis_name')
            traceback.print_exc()
예제 #2
0
    def update_plot(self,
                    fn_path=None,
                    modelspec=None,
                    rec_name=None,
                    signal_names=None,
                    channels=None,
                    time_range=None,
                    emit=True,
                    **kwargs):
        """Updates members and plots.

        Each keyword argument that is not ``None`` replaces the stored
        attribute of the same name. ``None`` keeps the current value. The
        plotting function named by ``self.fn_path`` (resolved with
        ``lookup_fn_at``) is then called with the stored settings, and the
        canvas is redrawn.

        Parameters
        ----------
        fn_path : optional
            Spec of the plotting function.
        modelspec : optional
            Model spec forwarded to the plotting function.
        rec_name : optional
            Key into ``self.window().rec_container`` selecting the recording.
        signal_names : sequence, optional
            Must contain exactly one signal name.
        channels : optional
            Channel selection forwarded to the plotting function.
        time_range : optional
            Time range forwarded to the plotting function.
        emit : bool
            After a successful draw, emit ``sigChannelsChanged`` with
            ``shape[0]`` of the plotted signal.
        **kwargs
            Extra keyword arguments forwarded to the plotting function.

        Raises
        ------
        ValueError
            If ``signal_names`` is given and does not have exactly one element.
        """
        if fn_path is not None:
            self.fn_path = fn_path
        if modelspec is not None:
            self.modelspec = modelspec
        if rec_name is not None:
            self.rec_name = rec_name
        if signal_names is not None:
            if len(signal_names) != 1:
                raise ValueError('NEMS can only plot a single signal.')
            self.signal_names = signal_names
        if channels is not None:
            self.channels = channels
        if time_range is not None:
            self.time_range = time_range

        plot_fn = lookup_fn_at(self.fn_path)
        plt.figure(f'{id(self)}')
        self.ax.clear()

        # fill the area
        pos = self.ax.get_position()
        pos.x0 = 0.1
        pos.x1 = 1
        self.ax.set_position(pos)

        rec = self.window().rec_container[self.rec_name]

        # sometimes the current channel index can be out of range; in that
        # case reset the parent's channel spin box to 0 and skip the emit
        try:
            plot_fn(rec=rec,
                    modelspec=self.modelspec,
                    sig_name=self.signal_names[0],
                    ax=self.ax,
                    # Use the stored selection, like every other setting.
                    # Passing the raw argument sent None whenever the caller
                    # only changed something else (e.g. rec_name).
                    channels=self.channels,
                    time_range=self.time_range,
                    **kwargs)
            self.canvas.figure.canvas.draw()
        except IndexError:
            self.parent().spinBox.setValue(0)
            return

        if emit:
            # TODO: how to know how many channels can be viewed here?
            #  but might not matter since plotting fn can handle channel input?
            self.sigChannelsChanged.emit(rec[self.signal_names[0]].shape[0])
예제 #3
0
파일: analysis_tab.py 프로젝트: LBHB/NEMS
    def run_custom_single(self):
        """Run a user-specified analysis on the current single selection.

        The analysis is named by the text in ``self.lineCustomSingle`` and is
        resolved with ``lookup_fn_at``. It is called with the ``ctx`` mapping
        returned by ``self.get_current_selection()`` unpacked as keyword
        arguments. Afterwards ``plt.show()`` displays any figures.

        Any ``Exception`` raised while resolving or running the analysis is
        reported on the console (message plus traceback) instead of being
        propagated to the caller.
        """
        import traceback

        analysis_name = self.lineCustomSingle.text()
        print('Custom analysis: ', analysis_name)
        # NOTE(review): these values are not used below; the call is kept in
        # case get_selected() validates the selection or has side effects.
        batch, selectedCellid, selectedModelname = self.get_selected()
        try:
            f = lookup_fn_at(analysis_name)
            _xf, ctx = self.get_current_selection()
            f(**ctx)
            plt.show()
        except Exception:
            # A bare ``except:`` used to hide the real error here (and also
            # caught KeyboardInterrupt/SystemExit). Keep running, but show
            # why the analysis failed.
            print('Unknown/incompatible analysis_name')
            traceback.print_exc()
예제 #4
0
    def on_action_custom_function(self):
        """Event handler for running custom function.

        Prompts for a custom function spec, pre-filled with
        ``self.custom_fn``. The spec is resolved with ``lookup_fn_at`` and the
        function is called with this object's ``cellid``, ``batch`` and
        ``modelname`` as keyword arguments. ``self.custom_fn`` is updated only
        when the spec resolves.

        If the dialog is cancelled or left empty, nothing happens. If the spec
        cannot be resolved, the error is logged and shown in the status bar
        rather than raised out of the handler.
        """
        input_text, accepted = QInputDialog.getText(
            self,
            'Custom function',
            'Enter spec to custom function:',
            text=self.custom_fn)
        if not accepted or not input_text:
            return

        try:
            custom_fn = lookup_fn_at(input_text)
        except Exception:
            # A bad spec is user input, not a programming error. Exceptions
            # escaping a Qt slot may terminate the application (e.g. PyQt5
            # >= 5.5 with the default excepthook), so report it and bail out.
            error_text = f'Could not find custom function: "{input_text}".'
            log.exception(error_text)
            self.statusbar.showMessage(error_text, 2000)
            return

        status_text = f'Running custom function: "{input_text}".'
        log.info(status_text)
        self.statusbar.showMessage(status_text, 2000)

        self.custom_fn = input_text
        # custom functions must implement this spec
        custom_fn(cellid=self.cellid,
                  batch=self.batch,
                  modelname=self.modelname)
예제 #5
0
파일: modelbuilder.py 프로젝트: LBHB/NEMS
def modelspec2tf(modelspec,
                 seed=0,
                 use_modelspec_init=True,
                 fs=100,
                 initializer='random_normal',
                 freeze_layers=None,
                 kernel_regularizer=None):
    """
    Build one TensorFlow layer per module of a modelspec.

    Was in nems.modelspec, but this is so tightly coupled to tf libraries that
    it probably belongs here instead.

    Parameters
    ----------
    modelspec : iterable of dict
        Module specs. Each module needs a ``'tf_layer'`` entry, resolved with
        ``lookup_fn_at`` to a class providing ``from_ms_layer``, and an
        ``'fn'`` entry used for matching and messages.
    seed, use_modelspec_init, fs, initializer
        Forwarded unchanged to every ``from_ms_layer`` call.
    freeze_layers : collection of int, optional
        Indices into ``modelspec`` whose layers are built with
        ``trainable=False``. Defaults to none.
    kernel_regularizer : str, optional
        Colon-separated regularizer spec, forwarded to ``from_ms_layer`` only
        for selected modules. If the third field is ``'firwc'``, only modules
        whose ``fn`` contains ``'weight_channels.basic'`` or ``'filter_bank'``
        are selected. Otherwise a default set of module names is used.

    Returns
    -------
    list
        The constructed layers, in modelspec order.

    Raises
    ------
    NotImplementedError
        If a module has no ``'tf_layer'`` entry.
    """
    # Substrings of m['fn'] that select which modules receive the regularizer.
    regularized_fns = ()
    if kernel_regularizer is not None:
        regstr = kernel_regularizer.split(":")
        if len(regstr) > 2 and regstr[2] == 'firwc':
            regularized_fns = ('weight_channels.basic', 'filter_bank')
        else:
            # Any other third field previously left the module list unset and
            # crashed with KeyError at the first layer; use the default set.
            regularized_fns = ('weight_channels.basic', 'Conv2D',
                               'WeightChannelsNew', 'state_dc_gain')

    if freeze_layers is None:
        freeze_layers = []

    layers = []
    for i, m in enumerate(modelspec):
        try:
            # Read the key first so a missing entry is reported clearly.
            tf_layer_spec = m['tf_layer']
            tf_layer = lookup_fn_at(tf_layer_spec)
        except KeyError as err:
            raise NotImplementedError(
                f'Layer "{m["fn"]}" does not have a tf equivalent.') from err

        layer_kwargs = dict(
            use_modelspec_init=use_modelspec_init,
            seed=seed,
            fs=fs,
            initializer=initializer,
            trainable=i not in freeze_layers,
        )
        # regularized_fns is empty when no regularizer was requested, so
        # kernel_regularizer is only passed for the selected modules.
        if any(fn in m['fn'] for fn in regularized_fns):
            log.info(
                f"Including {kernel_regularizer} regularizer for {m['fn']}")
            layer_kwargs['kernel_regularizer'] = kernel_regularizer
        layer = tf_layer.from_ms_layer(m, **layer_kwargs)

        if 'Conv2D' in m['fn']:
            if 'offset' in m['phi']:
                log.debug(f"Conv2D initializing offset: {m['phi']['offset']}")
            else:
                log.debug("Conv2D initializing offset pseudo-randomly")
        layers.append(layer)

    return layers