Example #1
class config(HasTraits):
    uuid = traits.Str(desc="UUID")

    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(mandatory=True,
                         desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    surf_dir = Directory(
        desc="freesurfer directory. subject id's should be the same")
    save_script_only = traits.Bool(False)
    # Execution
    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether and where the workflow keeps its '
             'intermediary files. True to keep intermediary files.'
    )
    timeout = traits.Float(30.0)
    # DataGrabber
    datagrabber = traits.Instance(Data, ())

    # segstats
    use_reg = traits.Bool(True)
    inverse_reg = traits.Bool(True)
    use_standard_label = traits.Bool(
        False, desc="use same label file for all subjects")
    label_file = traits.File()
    use_annotation = traits.Bool(
        False,
        desc="use same annotation file for all subjects "
             "(will warp to subject space)"
    )
    use_subject_annotation = traits.Bool(
        False,
        desc="you need to change datragrabber to\
                                           have outputs lh_annotation and rh_annotation"
    )
    annot_space = traits.String("fsaverage5",
                                desc="subject space of annot file")
    lh_annotation = traits.File()
    rh_annotation = traits.File()
    color_table_file = traits.Enum("Default", "Color_Table", "GCA_color_table",
                                   "None")
    color_file = traits.File()
    proj = traits.BaseTuple(("frac", 0, 1, 0.1), traits.Enum("abs", "frac"),
                            traits.Float(), traits.Float(), traits.Float())
    statname = traits.Str('segstats1', desc="description of the segstat")
Example #2
class Component(t.HasTraits):
    __axes_manager = None

    active = t.Property(t.CBool(True))
    name = t.Property(t.Str(''))

    def __init__(self, parameter_name_list):
        self.events = Events()
        self.events.active_changed = Event("""
            Event that triggers when the `Component.active` changes.

            The event triggers after the internal state of the `Component` has
            been updated.

            Arguments
            ---------
            obj : Component
                The `Component` that the event belongs to
            active : bool
                The new active state
            """,
                                           arguments=["obj", 'active'])
        self.parameters = []
        self.init_parameters(parameter_name_list)
        self._update_free_parameters()
        self.active = True
        self._active_array = None
        self.isbackground = False
        self.convolved = True
        self.parameters = tuple(self.parameters)
        self._id_name = self.__class__.__name__
        self._id_version = '1.0'
        self._position = None
        self.model = None
        self.name = ''
        self._whitelist = {
            '_id_name': None,
            'name': None,
            'active_is_multidimensional': None,
            '_active_array': None,
            'active': None
        }
        self._slicing_whitelist = {'_active_array': 'inav'}
        self._slicing_order = (
            'active',
            'active_is_multidimensional',
            '_active_array',
        )

    _name = ''
    _active_is_multidimensional = False
    _active = True

    @property
    def active_is_multidimensional(self):
        return self._active_is_multidimensional

    @active_is_multidimensional.setter
    def active_is_multidimensional(self, value):
        if not isinstance(value, bool):
            raise ValueError('Only boolean values are permitted')

        if value == self.active_is_multidimensional:
            return

        if value:  # Turn on
            if self._axes_manager.navigation_size < 2:
                _logger.info('`navigation_size` < 2, skipping')
                return
            # Store value at current position
            self._create_active_array()
            self._store_active_value_in_array(self._active)
            self._active_is_multidimensional = True
        else:  # Turn off
            # Get the value at the current position before switching it off
            self._active = self.active
            self._active_array = None
            self._active_is_multidimensional = False

    def _get_name(self):
        return self._name

    def _set_name(self, value):
        old_value = self._name
        if old_value == value:
            return
        if self.model:
            for component in self.model:
                if value == component.name:
                    raise ValueError("Another component already has "
                                     "the name " + str(value))
            self._name = value
            setattr(self.model.components,
                    slugify(value, valid_variable_name=True), self)
            self.model.components.__delattr__(
                slugify(old_value, valid_variable_name=True))
        else:
            self._name = value
        self.trait_property_changed('name', old_value, self._name)

    @property
    def _axes_manager(self):
        return self.__axes_manager

    @_axes_manager.setter
    def _axes_manager(self, value):
        for parameter in self.parameters:
            parameter._axes_manager = value
        self.__axes_manager = value

    def _get_active(self):
        if self.active_is_multidimensional is True:
            # Setting the property also updates self._active and
            # triggers the connected callbacks.
            self.active = self._active_array[self._axes_manager.indices[::-1]]
        return self._active

    def _store_active_value_in_array(self, value):
        self._active_array[self._axes_manager.indices[::-1]] = value

    def _set_active(self, arg):
        if self._active == arg:
            return
        old_value = self._active
        self._active = arg
        if self.active_is_multidimensional is True:
            self._store_active_value_in_array(arg)
        self.events.active_changed.trigger(active=self._active, obj=self)
        self.trait_property_changed('active', old_value, self._active)

    def init_parameters(self, parameter_name_list):
        for name in parameter_name_list:
            parameter = Parameter()
            self.parameters.append(parameter)
            parameter.name = name
            parameter._id_name = name
            setattr(self, name, parameter)
            if hasattr(self, 'grad_' + name):
                parameter.grad = getattr(self, 'grad_' + name)
            parameter.component = self
            self.add_trait(name, t.Instance(Parameter))

    def _get_long_description(self):
        if self.name:
            text = '%s (%s component)' % (self.name, self._id_name)
        else:
            text = '%s component' % self._id_name
        return text

    def _get_short_description(self):
        text = ''
        if self.name:
            text += self.name
        else:
            text += self._id_name
        text += ' component'
        return text

    def __repr__(self):
        text = '<%s>' % self._get_long_description()
        return text

    def _update_free_parameters(self):
        self.free_parameters = sorted(
            [par for par in self.parameters if par.free], key=lambda x: x.name)
        self._nfree_param = sum(
            [par._number_of_elements for par in self.free_parameters])

    def update_number_parameters(self):
        i = 0
        for parameter in self.parameters:
            i += parameter._number_of_elements
        self.nparam = i
        self._update_free_parameters()

    def fetch_values_from_array(self, p, p_std=None, onlyfree=False):
        if onlyfree is True:
            parameters = self.free_parameters
        else:
            parameters = self.parameters
        i = 0
        for parameter in sorted(parameters, key=lambda x: x.name):
            length = parameter._number_of_elements
            parameter.value = (p[i] if length == 1 else p[i:i + length])
            if p_std is not None:
                parameter.std = (p_std[i] if length == 1 else tuple(
                    p_std[i:i + length]))

            i += length

    def _create_active_array(self):
        shape = self._axes_manager._navigation_shape_in_array
        if len(shape) == 1 and shape[0] == 0:
            shape = [
                1,
            ]
        if (not isinstance(self._active_array, np.ndarray)
                or self._active_array.shape != shape):
            _logger.debug('Creating _active_array for {}.\n\tCurrent array '
                          'is:\n{}'.format(self, self._active_array))
            self._active_array = np.ones(shape, dtype=bool)

    def _create_arrays(self):
        if self.active_is_multidimensional:
            self._create_active_array()
        for parameter in self.parameters:
            parameter._create_array()

    def store_current_parameters_in_map(self):
        for parameter in self.parameters:
            parameter.store_current_value_in_array()

    def fetch_stored_values(self, only_fixed=False):
        if self.active_is_multidimensional:
            # Store the stored value in self._active and trigger the connected
            # functions.
            self.active = self.active
        if only_fixed is True:
            parameters = (set(self.parameters) - set(self.free_parameters))
        else:
            parameters = self.parameters
        parameters = [
            parameter for parameter in parameters
            if (parameter.twin is None
                or not isinstance(parameter.twin, Parameter))
        ]
        for parameter in parameters:
            parameter.fetch()

    def plot(self, only_free=True):
        """Plot the value of the parameters of the model

        Parameters
        ----------
        only_free : bool
            If True, only the value of the parameters that are free will
            be plotted.

        """
        if only_free:
            parameters = self.free_parameters
        else:
            parameters = self.parameters

        parameters = [k for k in parameters if k.twin is None]
        for parameter in parameters:
            parameter.plot()

    def export(self,
               folder=None,
               format="hspy",
               save_std=False,
               only_free=True):
        """Plot the value of the parameters of the model

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved. If
            `None` the
            current folder is used by default.
        format : str
            The extension of the file format, default "hspy".
        save_std : bool
            If True, also the standard deviation will be saved.
        only_free : bool
            If True, only the value of the parameters that are free will
             be
            exported.

        Notes
        -----
        The name of the files will be determined by each the Component
        and
        each Parameter name attributes. Therefore, it is possible to
        customise
        the file names modify the name attributes.

        """
        if only_free:
            parameters = self.free_parameters
        else:
            parameters = self.parameters

        parameters = [k for k in parameters if k.twin is None]
        for parameter in parameters:
            parameter.export(
                folder=folder,
                format=format,
                save_std=save_std,
            )

    def summary(self):
        for parameter in self.parameters:
            dim = len(parameter.map.squeeze().shape) if parameter.map \
                is not None else 0
            if parameter.twin is None:
                if dim <= 1:
                    print('%s = %s ± %s %s' % (parameter.name, parameter.value,
                                               parameter.std, parameter.units))

    def __call__(self):
        """Returns the corresponding model for the current coordinates

        Returns
        -------
        numpy array
        """

        axis = self.model.axis.axis[self.model.channel_switches]
        component_array = self.function(axis)
        return component_array

    def _component2plot(self, axes_manager, out_of_range2nans=True):
        old_axes_manager = None
        if axes_manager is not self.model.axes_manager:
            old_axes_manager = self.model.axes_manager
            self.model.axes_manager = axes_manager
            self.fetch_stored_values()
        s = self.__call__()
        if not self.active:
            s.fill(np.nan)
        if self.model.signal.metadata.Signal.binned is True:
            s *= self.model.signal.axes_manager.signal_axes[0].scale
        if out_of_range2nans is True:
            ns = np.empty(self.model.axis.axis.shape)
            ns.fill(np.nan)
            ns[self.model.channel_switches] = s
            s = ns
        if old_axes_manager is not None:
            self.model.axes_manager = old_axes_manager
            self.fetch_stored_values()
        return s

    def set_parameters_free(self, parameter_name_list=None):
        """
        Sets parameters in a component to free.

        Parameters
        ----------
        parameter_name_list : None or list of strings, optional
            If None, will set all the parameters to free.
            If list of strings, will set all the parameters with the same name
            as the strings in parameter_name_list to free.

        Examples
        --------
        >>> v1 = hs.model.components1D.Voigt()
        >>> v1.set_parameters_free()
        >>> v1.set_parameters_free(parameter_name_list=['area','centre'])

        See also
        --------
        set_parameters_not_free
        hyperspy.model.BaseModel.set_parameters_free
        hyperspy.model.BaseModel.set_parameters_not_free
        """

        parameter_list = []
        if not parameter_name_list:
            parameter_list = self.parameters
        else:
            for _parameter in self.parameters:
                if _parameter.name in parameter_name_list:
                    parameter_list.append(_parameter)

        for _parameter in parameter_list:
            _parameter.free = True

    def set_parameters_not_free(self, parameter_name_list=None):
        """
        Sets parameters in a component to not free.

        Parameters
        ----------
        parameter_name_list : None or list of strings, optional
            If None, will set all the parameters to not free.
            If list of strings, will set all the parameters with the same name
            as the strings in parameter_name_list to not free.

        Examples
        --------
        >>> v1 = hs.model.components1D.Voigt()
        >>> v1.set_parameters_not_free()
        >>> v1.set_parameters_not_free(parameter_name_list=['area','centre'])

        See also
        --------
        set_parameters_free
        hyperspy.model.BaseModel.set_parameters_free
        hyperspy.model.BaseModel.set_parameters_not_free
        """

        parameter_list = []
        if not parameter_name_list:
            parameter_list = self.parameters
        else:
            for _parameter in self.parameters:
                if _parameter.name in parameter_name_list:
                    parameter_list.append(_parameter)

        for _parameter in parameter_list:
            _parameter.free = False

    def _estimate_parameters(self, signal):
        if self._axes_manager != signal.axes_manager:
            self._axes_manager = signal.axes_manager
            self._create_arrays()

    def as_dictionary(self, fullcopy=True):
        """Returns component as a dictionary
        For more information on method and conventions, see
        :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
        Parameters
        ----------
        fullcopy : Bool (optional, False)
            Copies of objects are stored, not references. If any found,
            functions will be pickled and signals converted to dictionaries
        Returns
        -------
        dic : dictionary
            A dictionary, containing at least the following fields:
            parameters : list
                a list of dictionaries of the parameters, one per
            _whitelist : dictionary
                a dictionary with keys used as references saved attributes, for
                more information, see
                :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
            * any field from _whitelist.keys() *
        """
        dic = {
            'parameters': [p.as_dictionary(fullcopy) for p in self.parameters]
        }
        export_to_dictionary(self, self._whitelist, dic, fullcopy)
        return dic

    def _load_dictionary(self, dic):
        """Load data from dictionary.
        Parameters
        ----------
        dict : dictionary
            A dictionary containing following items:
            _id_name : string
                _id_name of the original component, used to create the
                dictionary. Has to match with the self._id_name
            parameters : list
                A list of dictionaries, one per parameter of the component (see
                parameter.as_dictionary() documentation for more)
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                component attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.load_from_dictionary`
            * any field from _whitelist.keys() *
        Returns
        -------
        twin_dict : dictionary
            Dictionary of 'id' values from input dictionary as keys with all of
            the parameters of the component, to be later used for setting up
            correct twins.
        """

        if dic['_id_name'] == self._id_name:
            load_from_dictionary(self, dic)
            id_dict = {}
            for p in dic['parameters']:
                idname = p['_id_name']
                if hasattr(self, idname):
                    par = getattr(self, idname)
                    t_id = par._load_dictionary(p)
                    id_dict[t_id] = par
                else:
                    raise ValueError(
                        "_id_name of parameters in component and dictionary do not match"
                    )
            return id_dict
        else:
            raise ValueError(
                "_id_name of component and dictionary do not match, \ncomponent._id_name = %s\
                    \ndictionary['_id_name'] = %s" %
                (self._id_name, dic['_id_name']))
Example #3
class Smoothing(t.HasTraits):
    # The following is disabled because as of traits 4.6 the Color trait
    # imports traitsui (!)
    # try:
    #     line_color = t.Color("blue")
    # except ModuleNotFoundError:
    #     # traitsui is required to define this trait so it is not defined when
    #     # traitsui is not installed.
    #     pass
    line_color_ipy = t.Str("blue")
    differential_order = t.Int(0)

    @property
    def line_color_rgb(self):
        if hasattr(self, "line_color"):
            try:
                # PyQt and WX
                return np.array(self.line_color.Get()) / 255.
            except AttributeError:
                try:
                    # PySide
                    return np.array(self.line_color.getRgb()) / 255.
                except BaseException:
                    return matplotlib.colors.to_rgb(self.line_color_ipy)
        else:
            return matplotlib.colors.to_rgb(self.line_color_ipy)

    def __init__(self, signal):
        self.ax = None
        self.data_line = None
        self.smooth_line = None
        self.signal = signal
        self.single_spectrum = self.signal.get_current_signal().deepcopy()
        self.axis = self.signal.axes_manager.signal_axes[0].axis
        self.plot()

    def plot(self):
        if self.signal._plot is None or not self.signal._plot.is_active:
            self.signal.plot()
        hse = self.signal._plot
        l1 = hse.signal_plot.ax_lines[0]
        self.original_color = l1.line.get_color()
        l1.set_line_properties(color=self.original_color, type='scatter')
        l2 = drawing.signal1d.Signal1DLine()
        l2.data_function = self.model2plot

        l2.set_line_properties(color=self.line_color_rgb, type='line')
        # Add the line to the figure
        hse.signal_plot.add_line(l2)
        l2.plot()
        self.data_line = l1
        self.smooth_line = l2
        self.smooth_diff_line = None

    def update_lines(self):
        self.smooth_line.update()
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.update()

    def turn_diff_line_on(self, diff_order):

        self.signal._plot.signal_plot.create_right_axis()
        self.smooth_diff_line = drawing.signal1d.Signal1DLine()
        self.smooth_diff_line.axes_manager = self.signal.axes_manager
        self.smooth_diff_line.data_function = self.diff_model2plot
        self.smooth_diff_line.set_line_properties(color=self.line_color_rgb,
                                                  type='line')
        self.signal._plot.signal_plot.add_line(self.smooth_diff_line,
                                               ax='right')

    def _line_color_ipy_changed(self):
        if hasattr(self, "line_color"):
            self.line_color = str(self.line_color_ipy)
        else:
            self._line_color_changed(None, None)

    def turn_diff_line_off(self):
        if self.smooth_diff_line is None:
            return
        self.smooth_diff_line.close()
        self.smooth_diff_line = None

    def _differential_order_changed(self, old, new):
        if new == 0:
            self.turn_diff_line_off()
            return
        if old == 0:
            self.turn_diff_line_on(new)
            self.smooth_diff_line.plot()
        else:
            self.smooth_diff_line.update(force_replot=False)

    def _line_color_changed(self, old, new):
        self.smooth_line.line_properties = {'color': self.line_color_rgb}
        if self.smooth_diff_line is not None:
            self.smooth_diff_line.line_properties = {
                'color': self.line_color_rgb
            }
        try:
            # it seems that changing the properties can be done before the
            # first rendering event, which can cause issue with blitting
            self.update_lines()
        except AttributeError:
            pass

    def diff_model2plot(self, axes_manager=None):
        smoothed = np.diff(self.model2plot(axes_manager),
                           self.differential_order)
        return smoothed

    def close(self):
        if self.signal._plot.is_active:
            if self.differential_order != 0:
                self.turn_diff_line_off()
            self.smooth_line.close()
            self.data_line.set_line_properties(color=self.original_color,
                                               type='line')
Example #4
class Data(HasTraits):
    fields = traits.List(traits.Instance(DataBase, ()))
    base_directory = Directory(os.path.abspath('.'))
    template = traits.Str('*')
    template_args = traits.Dict({"a": "b"}, usedefault=True)
    field_template = traits.Dict({"key": ["hi"]}, usedefault=True)

    if use_view:
        check = traits.Button("Check")
        view = get_view()

    def __init__(self, outfields=None):
        if outfields:
            d_ft = {}
            d_ta = {}
            for out in outfields:
                d_ft[out] = '%s'
                d_ta[out] = [['name']]
            self.field_template = d_ft
            self.template_args = d_ta
            self.outfields = outfields

    def _get_infields(self):
        infields = []
        for f in self.fields:
            infields.append(f.name)
        return infields

    def _add_iterable(self, field):
        import nipype.interfaces.utility as niu
        import nipype.pipeline.engine as pe
        it = pe.Node(niu.IdentityInterface(fields=[field.name]),
                     name=field.name + "_iterable")
        it.iterables = (field.name, field.values)
        return it

    def _set_inputs(self):
        self._node_added = False
        set_dict = {}
        for f in self.fields:
            if not f.iterable:
                set_dict[f.name] = f.values
            else:
                it = self._add_iterable(f)
                self._node_added = True
                self._wk.connect(it, f.name, self._dg, f.name)
        self._dg.inputs.trait_set(**set_dict)

    def create_dataflow(self):
        import nipype.interfaces.io as nio
        import nipype.pipeline.engine as pe
        self._wk = pe.Workflow(name='custom_datagrabber')
        self._dg = pe.Node(nio.DataGrabber(outfields=self.outfields,
                                           infields=self._get_infields()),
                           name='datagrabber')
        self._set_inputs()
        self._dg.inputs.base_directory = self.base_directory
        self._dg.inputs.field_template = self.field_template
        self._dg.inputs.template_args = self.template_args
        self._dg.inputs.template = self.template
        if not self._node_added:
            self._wk.add_nodes([self._dg])
        return self._wk

    def get_fields(self):
        foo = self.get()
        d = {}
        for key, item in foo.items():
            if not key.startswith('_'):
                if isinstance(item, list):
                    d[key] = []
                    for it in item:
                        if isinstance(it, DataBase):
                            d[key].append(it.get())
                        else:
                            d[key].append(it)
                else:
                    d[key] = item
        return d

    def set_fields(self, d):
        for key, item in d.items():
            if not key == "fields":
                self.set(**{key: item})
            else:
                foo = []
                for f in item:
                    tmp = DataBase()
                    tmp.set(**f)
                    foo.append(tmp)
                self.set(**{key: foo})

    def _check_fired(self):
        dg = self.create_dataflow()
        dg.run()
Example #5
def configure_controller(cls):
    c = Controller()
    c.add_trait('param_type', traits.Str('Str'))
    return c
Example #6
class DACChannel(traits.HasTraits):
    def __init__(self, setchannelName, port, connection, rpiADC):
        self.channelName = setchannelName
        self.x = ad.ADEvalBC()
        self.port = port
        self.connection = connection
        self.rpiADC = rpiADC
        self.wmLockTimer = Timer(1000, self.wmLock2)
        self.wmLockTimer.Stop()
#        time.sleep(5)
#        self.start_timer()

    update = traits.Button()
    set = traits.Button()

    pinMode = '0'
    relockMode = traits.Enum("Manual Mode", "Doubling cavity Relock",
                             "Wavemeter Relock", "Wavemeter Lock")
    #set_Relock = traits.Button()

    #---- MANUAL MODE ----#

    voltage = traits.Float(desc="Voltage of the Channel")
    setVoltage = traits.Float(desc="set Voltage")
    powerModeMult = 0
    channelName = traits.Str(desc="Name")
    channelDescription = traits.Str(desc="Description")
    powerMode = traits.Enum(
        5,
        10,
        10.8,
        desc="power Mode",
    )
    bipolar = traits.Bool(desc="Bipolarity")
    channelMessage = traits.Str()
    bytecode = traits.Str()
    port = traits.Str()

    def pinSet(self, channel, mode):
        cmd = "pin=" + channel + mode
        self.connection.send(cmd)
        #print cmd

    def _update_fired(self):
        if self.bipolar == True:
            a = "bipolar"
            bip = True
            self.powerModeMult = -1
        else:
            a = "unipolar"
            bip = False
            self.powerModeMult = 0

        cmd = "cmd=" + self.x.setMaxValue(self.port, self.powerMode, bip)
        #print cmd
        self.connection.send(cmd)
        b = "Mode set to %.1f" % self.powerMode
        self.channelMessage = b + ' ' + a
        self._set_fired()

    def _set_fired(self):
        if ((self.setVoltage > self.powerMode) and (self.bipolar == False)):
            print "setVoltage out of bounds. Not sending."
        elif ((abs(self.setVoltage) > self.powerMode)
              and (self.bipolar == True)):
            print "setVoltage out of bounds. Not sending."
        else:
            cmd = "cmd=" + self.x.generate_voltage(
                self.port, self.powerMode, self.powerModeMult * self.powerMode,
                self.setVoltage)
            self.connection.send(cmd)
            self.bytecode = self.x.generate_voltage(
                self.port, self.powerMode, self.powerModeMult * self.powerMode,
                self.setVoltage)
            self.voltage = self.setVoltage

#---- MANUAL MODE GUI ----#

    voltageGroup = traitsui.HGroup(traitsui.VGroup(
        traitsui.Item('voltage',
                      label="Measured Voltage",
                      style_sheet='* { font-size: 18px;  }',
                      style="readonly"),
        traitsui.Item('setVoltage', label="Set Value"),
    ),
                                   traitsui.Item('set', show_label=False),
                                   show_border=True)

    powerGroup = traitsui.VGroup(traitsui.HGroup(
        traitsui.Item('powerMode', label="Power Mode"),
        traitsui.Item('bipolar'),
    ),
                                 traitsui.HGroup(
                                     traitsui.Item('update', show_label=False),
                                     traitsui.Item('channelMessage',
                                                   show_label=False,
                                                   style="readonly"),
                                 ),
                                 traitsui.Item('bytecode',
                                               show_label=False,
                                               style="readonly"),
                                 traitsui.Item('switch_Lock'),
                                 show_border=True)

    manualGroup = traitsui.VGroup(traitsui.Item('channelName',
                                                label="Channel Name",
                                                style="readonly"),
                                  voltageGroup,
                                  show_border=True,
                                  visible_when='relockMode == "Manual Mode"')

    #---- DOUBLING CAVITY MODE GUI----#

    adcChannel = traits.Enum(0,
                             1,
                             2,
                             3,
                             4,
                             5,
                             6,
                             7,
                             8,
                             9,
                             10,
                             11,
                             12,
                             desc="Channel of the rpiADC")
    adcVoltage = traits.Float(desc="Voltage on rpiADC Channel")
    DCscan_and_lock = traits.Button()
    switch_Lock = traits.Button()
    DCconnect = traits.Bool()
    DCautolock = traits.Button()
    DCadcVoltages = None
    DCadcVoltagesMean = None
    DCtolerance = 0.01  #Volt
    DCmistakeCounter = 0
    DCminBoundary = traits.Float()
    DCmaxBoundary = traits.Float()
    DCnoOfSteps = traits.Int()

    #Updates voltage of the selected channel"
    def _adcVoltage_update(self):
        self.adcVoltage = self._adcVoltage_get()

#Gets voltage of the selected channel via rpiADC Client

    def _adcVoltage_get(self):
        return self.rpiADC.getResults()[self.adcChannel]
#        print "latest results = %s " % self.rpiADC.latestResults
#        self.rpiADC.getResults()
#        print "latest results = %s " % self.rpiADC.latestResults
#        if self.adcChannel in self.rpiADC.latestResults:
#            return self.rpiADC.latestResults[self.adcChannel]
#        else:
#            return -999

#As soon as the connect button is checked, automatic updating of the adc voltage is initiated.
#When it is unchecked, the update stops

    def _DCconnect_changed(self):

        if self.DCconnect == True:
            self._start_PD_timer()
        else:
            self.PDtimer.stop()
            self.adcVoltage = -999
            print "PD timer stopped."

#Starts the timer that only updates the displayed voltage

    def _start_PD_timer(self):
        self.PDtimer = Timer(1000.0, self._adcVoltage_update)

    #Starts a timer that updates the displayed voltage and also does the "in lock" checking
    def _start_PD_lock_timer(self):
        self.PDtimer = Timer(1000.0, self.update_PD_and_Lock)

#Checks whether everything is still in lock and updates the displayed voltage. It counts as still
#in lock when the measured voltage is within DCtolerance of the mean of the last five measured voltages.

    def update_PD_and_Lock(self):
        self._adcVoltage_update()  #still display Voltage
        pdVoltage = self._adcVoltage_get()
        #print "Updated Frequency"
        #mistakeCounter = 0
        if len(self.DCadcVoltages) < 5:  #number of measurements that will be compared
            print "Getting Data for Lock. Do not unlock!"
            self.DCadcVoltages = np.append(self.DCadcVoltages, pdVoltage)
        else:
            self.DCadcVoltagesMean = np.mean(
                self.DCadcVoltages)  #Mean Frequency to compare to
            if (abs(pdVoltage - self.DCadcVoltagesMean) < self.DCtolerance):
                self.DCadcVoltages = np.append(self.DCadcVoltages, pdVoltage)
                self.DCadcVoltages = np.delete(
                    self.DCadcVoltages, 0)  #keep Array at constant length
                print "Still locked."
                if self.DCmistakeCounter > 0:
                    self.DCmistakeCounter = 0
            else:
                self.DCmistakeCounter += 1
                if self.DCmistakeCounter > 5:
                    self.PDtimer.stop()
                    self._start_PD_timer()  #keep Frequency on display..
                    self._DCscan_and_lock_fired()
                    self._DCautolock_fired()
                else:
                    print self.DCmistakeCounter

    #This button is used, when everything is already locked. It prepares the voltage mean array, stops the PD timer and starts the update_PD_and_Lock timer routine
    def _DCautolock_fired(self):
        if self.DCconnect == True:
            self.DCadcVoltages = np.array([])
            self.PDtimer.stop()
            self._start_PD_lock_timer()
        else:
            print "No adcRPi connected."

#This function (button) scans the voltage and reads the RPiADC voltage of the selected channel. It subsequently attempts
#to lock the Cavity at the voltage where the RPiADC voltage is highest. The algorithm for the lock is as follows:
#1) It does a coarse DAC voltage scan with the parameters selected in the input (minBoundary etc). This scan is terminated when
#	the adc voltage goes above a threshold (3.2 V - hardcoded, maybe add input), or after scanning the whole range.
#2) A fine scan 0.3V (dac) around the maximum (or around the 3.2V dac voltage) is done. Currently the number of steps is hardcoded to 600,
#	which worked well. This scan is terminated either by reaching a 3.3V (adc) threshold or after finishing.
#3) The Cavity is locked at a dac voltage that corresponds to either the maximum after the whole scan or the 3.3V threshold.

    def _DCscan_and_lock_fired(self):
        self.pinSet(self.port, '1')  #Unlock
        #If the cavity was in lock before, we cannot directly start scanning,
        #or else it will read 3.3V as the first value.
        time.sleep(1)
        voltages = np.linspace(
            self.DCminBoundary, self.DCmaxBoundary,
            self.DCnoOfSteps)  #Parameters for first scan set by GUI
        diodeVoltages = np.array([])
        for entry in voltages:  #First scan
            self.setVoltage = entry
            self._set_fired()  #update setVoltage
            diodeVoltages = np.append(diodeVoltages, self._adcVoltage_get())
            volt = self._adcVoltage_get()
            print volt
            if volt >= 3.2:  #Threshold for first scan
                break
            #sleep time between every step for adc readout; maybe the scan can be made faster
            time.sleep(0.05)
        #DAC voltage corresponding to the highest adc voltage
        self.setVoltage = voltages[diodeVoltages.argmax(axis=0)]
        print "Attempting to reduce scan Range"
        print self.setVoltage
        time.sleep(2.0)
        if self.setVoltage < 0.3:  #make sure we do not scan below zero, as the lock box does not like negative voltages.
            auxSetVolt = 0
        else:
            auxSetVolt = self.setVoltage - 0.3
        voltages = np.linspace(auxSetVolt, self.setVoltage + 0.3,
                               600)  #parameters for second scan
        diodeVoltages = np.array([])
        for entry in voltages:  #Second scan
            if entry > self.powerMode:  #Our voltage can not go above our maximum value
                break
            self.setVoltage = entry
            self._set_fired()
            diodeVoltages = np.append(diodeVoltages, self._adcVoltage_get())
            volt = self._adcVoltage_get()
            print volt
            if volt >= 3.3:  #threshold for second scan
                print self.setVoltage
                self.pinSet(self.port, '0')
                return
            time.sleep(0.1)
        self.setVoltage = voltages[diodeVoltages.argmax(axis=0)]
        print "DAC Voltage set to %f" % voltages[diodeVoltages.argmax(axis=0)]
        print ".. this corresponds to a diode Voltage of %f" % diodeVoltages[
            diodeVoltages.argmax(axis=0)]
        self._set_fired()
        time.sleep(0.2)
        self.pinSet(self.port, '0')  #Lock
        print "Voltage set to %f to attempt relock." % self.setVoltage
        return

    #Changes from lock to dither and the other way around
    def _switch_Lock_fired(self):
        if self.pinMode == '0':
            self.pinMode = '1'
            self.pinSet(self.port, self.pinMode)
        else:
            self.pinMode = '0'
            self.pinSet(self.port, self.pinMode)

#Gui Stuff#

    DCgroup = traitsui.VGroup(
        traitsui.HGroup(
            traitsui.VGroup(
                traitsui.Item('adcChannel'),
                traitsui.Item('adcVoltage', style='readonly'),
            ),
            traitsui.VGroup(
                traitsui.Item('DCmaxBoundary'),
                traitsui.Item('DCminBoundary'),
                traitsui.Item('DCnoOfSteps'),
            ),
        ),
        #traitsui.Item('switch_Lock'),
        traitsui.HGroup(
            traitsui.VGroup(
                traitsui.Item('DCscan_and_lock'),
                traitsui.Item('DCautolock'),
            ), traitsui.Item('DCconnect')),
        visible_when='relockMode == "Doubling cavity Relock"',
        show_border=True,
    )

    #---- WAVEMETER GUI ----#
    wavemeter = traits.Enum("Humphry", "Ion Cavity")
    wmChannel = traits.Enum(1, 2, 3, 4, 5, 6, 7, 8)
    wmFrequency = traits.Float()
    wmConnected = traits.Bool()
    wmHWI = None
    wmIP = None
    wmPort = None
    wmReLock = traits.Button()
    wmFrequencyLog = np.array([])
    wmMeanFrequency = None
    wmTolerance = 0.0000006  #Frequency tolerance in THz, set to 60 MHz in accordance with wm accuracy
    mistakeCounter = 0
    #wmPolarity = traits.Enum("+", "-")
    wmEmptyMemory = traits.Button()

    #When hitting wmConnected, a connection to the wavemeter and readout is established
    def _wmConnected_changed(self):
        if self.wmConnected == True:
            if self.wavemeter == "Humphry":
                self.wmIP = '192.168.0.111'
                self.wmPort = 6101
            elif self.wavemeter == "Ion Cavity":
                self.wmIP = '192.168.32.2'
                self.wmPort = 6100

            self.wmHWI = PyHWI.DECADSClientConnection('WLM', self.wmIP,
                                                      self.wmPort)
            #frequencyArray = wmHWI.getFrequency(True, self.wmChannel)
            self.start_timer()
        else:
            self.wmFrequency = -999
            self.timer.Stop()  #stops either lock_timer or read_timer.
            print "Timer stopped"

    #GUI stuff
    wmGroup = traitsui.VGroup(traitsui.Item('wmFrequency', style='readonly'),
                              traitsui.Item('wmConnected'),
                              traitsui.Item('wmReLock'),
                              traitsui.HGroup(
                                  traitsui.Item('wavemeter'),
                                  traitsui.Item('wmChannel'),
                                  traitsui.Item('wmEmptyMemory',
                                                show_label=False)),
                              visible_when='relockMode == "Wavemeter Relock"')

    #Resets the memory of the lock (the array which saves the last read frequencies)
    def _wmEmptyMemory_fired(self):
        self.wmFrequencyLog = np.array([])
        print "Memory empty."

#starts readout-only timer

    def start_timer(self):
        print "Timer started"
        self.timer = Timer(1000.0, self.update_wm)

#starts readout-and-lock timer, analogue to the doubling cavity case

    def start_lock_timer(self):
        print "Lock timer started"
        self.timer = Timer(1000.0, self.update_wm_and_lock)

#updates the displayed frequency

    def update_wm(self):
        frequencyArray = self.wmHWI.getFrequency(True, self.wmChannel)
        self.wmFrequency = frequencyArray[1]

#updates frequency and checks if its still near the mean of the last five (hardcoded, see below) measured frequencies. If not, it attempts a relock.

    def update_wm_and_lock(self):
        self.update_wm()
        frequencyArray = self.wmHWI.getFrequency(True, self.wmChannel)
        #print "Updated Frequency"
        #mistakeCounter = 0
        if len(self.wmFrequencyLog) < 5:  #number of measurements that will be compared
            print "Getting Data for Lock. Do not unlock!"
            self.wmFrequencyLog = np.append(self.wmFrequencyLog,
                                            frequencyArray[1])
        else:
            self.wmMeanFrequency = np.mean(
                self.wmFrequencyLog)  #Mean Frequency to compare to
            if (abs(frequencyArray[1] - self.wmMeanFrequency) <
                    self.wmTolerance):
                self.wmFrequencyLog = np.append(self.wmFrequencyLog,
                                                frequencyArray[1])
                self.wmFrequencyLog = np.delete(
                    self.wmFrequencyLog, 0)  #keep Array at constant length
                print "Still locked."
                if self.mistakeCounter > 0:
                    self.mistakeCounter = 0
            else:
                self.mistakeCounter += 1
                if self.mistakeCounter > 5:  #number of measurements that still count as locked, though the frequency is not within boundaries
                    self.timer.stop()
                    self.start_timer()  #keep Frequency on display..
                    self.wmRelock(self.wmMeanFrequency)
                else:
                    print self.mistakeCounter

    #Relock procedure.
#For now this scans only one time, with a hardcoded number of steps. It might be worth modifying this to look like the doubling cavity relock procedure.

    def wmRelock(self, wantedFrequency):
        self.pinSet(self.port, '1')
        voltages = np.linspace(self.powerModeMult * self.powerMode,
                               self.powerMode, 10)
        wmRelockTry = 0
        try:
            while (wmRelockTry < 5):  #attempt relock five times
                for entry in voltages:
                    self.setVoltage = entry
                    self._set_fired()
                    time.sleep(1.0)
                    frequencyArray = self.wmHWI.getFrequency(
                        True, self.wmChannel)
                    if (abs(frequencyArray[1] - wantedFrequency) <
                            self.wmTolerance):
                        print "Relock_attempt!"
                        self.pinSet(self.port, '0')
                        self._wmLock_fired()
                        raise GetOutOfLoop  #Breaks out of both nested loops
                        #via the exception handled below.
                wmRelockTry += 1
                print "Relock try %f not succesful" % wmRelockTry
            print "Was not able to Relock."
        except GetOutOfLoop:
            print "gotOutOfLoop"
            pass

    def _wmReLock_fired(self):
        #self.wmFrequencyLog = np.array([])
        self.mistakeCounter = 0
        if self.wmConnected == True:
            #Switch from the read-only timer to the read-and-lock timer.
            #If both run at the same time we run into timing problems.
            self.timer.Stop()
            self.start_lock_timer()
        else:
            print "No Wavemeter connected!"

        print "hi"

#---- WAVEMETERLOCK - DE 17092018 GUI ----#

    wavemeter = traits.Enum("Humphry", "Ion Cavity")
    wmChannel = traits.Enum(1, 2, 3, 4, 5, 6, 7, 8)
    wmFrequency = traits.Float()
    wmConnected = traits.Bool()
    wmHWI = None
    wmIP = None
    wmPort = None
    isRunning = traits.Bool(False)
    wmVoltage = 0
    wmLockTimer = traits.Instance(Timer)
    wmLockStart = traits.Button()
    wmLockStop = traits.Button()
    wmFrequencyLog = np.array([])
    wmMeanFrequency = None
    wmTargetFrequency = traits.Float()  #frequency to lock
    wmGain = traits.Float()  #Gain for wavemeterlock
    wmTolerance = 0.0000006  #Frequency tolerance in THz, set to 60 MHz in accordance with wm accuracy
    mistakeCounter = 0
    #wmPolarity = traits.Enum("+", "-")
    wmEmptyMemory = traits.Button()

    #GUI stuff
    wmlockGroup = traitsui.VGroup(
        traitsui.HGroup(traitsui.Item('wavemeter'), traitsui.Item('wmChannel'),
                        traitsui.Item('wmEmptyMemory', show_label=False)),
        traitsui.Item('wmFrequency', style='readonly'),
        traitsui.Item('wmConnected'),
        traitsui.Item('wmTargetFrequency'),
        traitsui.Item('wmGain'),
        traitsui.HGroup(
            traitsui.Item('wmLockStart', visible_when="isRunning == False"),
            traitsui.Item('wmLockStop', visible_when="isRunning == True")),
        visible_when='relockMode == "Wavemeter Lock"')

    def _wmLockStart_fired(self):
        print "Start: Wavemeterlock"
        self.isRunning = True
        self.wmLockTimer.Start()

    def wmLock2(self):
        # Calculate error in MHz
        error = (self.wmFrequency - self.wmTargetFrequency) * 10**3
        if abs(error) < 5000:
            self.wmVoltage = self.wmVoltage + error * self.wmGain
            if self.wmVoltage > 10:
                self.wmVoltage = 10
            elif self.wmVoltage < -10:
                self.wmVoltage = -10

        cmd = "cmd=" + self.x.generate_voltage(
            self.port, self.powerMode, self.powerModeMult * self.powerMode,
            self.wmVoltage)
        self.connection.sendwithoutcomment(cmd)
        self.bytecode = self.x.generate_voltage(
            self.port, self.powerMode, self.powerModeMult * self.powerMode,
            self.wmVoltage)

        #self.saveToFile( "frequency_data.csv" )

    def _wmLockStop_fired(self):
        print "Stop: Wavemeterlock"
        self.isRunning = False

        cmd = "cmd=" + self.x.generate_voltage(
            self.port, self.powerMode, self.powerModeMult * self.powerMode,
            0)  #voltage back to 0
        self.connection.sendwithoutcomment(cmd)
        self.bytecode = self.x.generate_voltage(
            self.port, self.powerMode, self.powerModeMult * self.powerMode, 0)
        self.wmVoltage = 0

        self.wmLockTimer.Stop()

    def saveToFile(self, fileName):

        f = open(fileName, "a")

        f.write("%f \n" % self.wmFrequency)

        f.close()


#---- PUT TOGETHER CHANNEL ----#

    selectionGroup = traitsui.HGroup(traitsui.Item('relockMode'),
                                     #traitsui.Item('set_Relock'),
                                     )

    channelGroup = traitsui.VGroup(
        selectionGroup,
        powerGroup,
        wmGroup,
        wmlockGroup,
        DCgroup,
        manualGroup,
    )

    traits_view = traitsui.View(channelGroup)
Example #7
class config(HasTraits):
    uuid = traits.Str(desc="UUID")

    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(mandatory=True,
                         desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    surf_dir = Directory(
        desc="freesurfer directory. subject id's should be the same")
    save_script_only = traits.Bool(False)
    # Execution
    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether and where the workflow keeps its '
             'intermediary files. True to keep intermediary files.'
    )
    # Data
    datagrabber = traits.Instance(Data, ())
    #subject_id = traits.String()
    #contrast = traits.File()
    #mask_contrast = traits.File()
    use_contrast_mask = traits.Bool(True)
    #reg_file = traits.File()
    #mean_image = traits.File()
    background_thresh = traits.Float(0.5)
    hemi = traits.List(['lh', 'rh'])
    roi = traits.List(
        ['superiortemporal', 'bankssts'],
        traits.Enum('superiortemporal', 'bankssts', 'caudalanteriorcingulate',
                    'caudalmiddlefrontal', 'corpuscallosum', 'cuneus',
                    'entorhinal', 'fusiform', 'inferiorparietal',
                    'inferiortemporal', 'isthmuscingulate', 'lateraloccipital',
                    'lateralorbitofrontal', 'lingual', 'medialorbitofrontal',
                    'middletemporal', 'parahippocampal', 'paracentral',
                    'parsopercularis', 'parsorbitalis', 'parstriangularis',
                    'pericalcarine', 'postcentral', 'posteriorcingulate',
                    'precentral', 'precuneus', 'rostralanteriorcingulate',
                    'rostralmiddlefrontal', 'superiorfrontal',
                    'superiorparietal', 'supramarginal', 'frontalpole',
                    'temporalpole', 'transversetemporal', 'insula'),
        usedefault=True)  #35 FreeSurfer regions
    thresh = traits.Float(1.5)
Example #8
File: axes.py  Project: kif/hyperspy
class DataAxis(t.HasTraits):
    name = t.Str()
    units = t.Str()
    scale = t.Float()
    offset = t.Float()
    size = t.CInt()
    low_value = t.Float()
    high_value = t.Float()
    value = t.Range('low_value', 'high_value')
    low_index = t.Int(0)
    high_index = t.Int()
    slice = t.Instance(slice)
    navigate = t.Bool(t.Undefined)
    index = t.Range('low_index', 'high_index')
    axis = t.Array()
    continuous_value = t.Bool(False)

    def __init__(self,
                 size,
                 index_in_array=None,
                 name=t.Undefined,
                 scale=1.,
                 offset=0.,
                 units=t.Undefined,
                 navigate=t.Undefined):
        super(DataAxis, self).__init__()
        self.name = name
        self.units = units
        self.scale = scale
        self.offset = offset
        self.size = size
        self.high_index = self.size - 1
        self.low_index = 0
        self.index = 0
        self.update_axis()
        self.navigate = navigate
        self.axes_manager = None
        self.on_trait_change(self.update_axis, ['scale', 'offset', 'size'])
        self.on_trait_change(self.update_value, 'index')
        self.on_trait_change(self.set_index_from_value, 'value')
        self.on_trait_change(self._update_slice, 'navigate')
        self.on_trait_change(self.update_index_bounds, 'size')
        # The slice must be updated even if the default value did not
        # change to correctly set its value.
        self._update_slice(self.navigate)

    @property
    def index_in_array(self):
        if self.axes_manager is not None:
            return self.axes_manager._axes.index(self)
        else:
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_array attribute"
                " is not defined")

    @property
    def index_in_axes_manager(self):
        if self.axes_manager is not None:
            return self.axes_manager._get_axes_in_natural_order().index(self)
        else:
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_axes_manager attribute"
                " is not defined")

    def _get_positive_index(self, index):
        if index < 0:
            index = self.size + index
            if index < 0:
                raise IndexError("index out of bounds")
        return index

    def _get_index(self, value):
        if isinstance(value, float):
            return self.value2index(value)
        else:
            return value

    def _slice_me(self, slice_):
        """Returns a slice to slice the corresponding data axis and 
        change the offset and scale of the DataAxis acordingly.
        
        Parameters
        ----------
        slice_ : {float, int, slice}
        
        Returns
        -------
        my_slice : slice
        
        """
        i2v = self.index2value
        v2i = self.value2index

        if isinstance(slice_, slice):
            start = slice_.start
            stop = slice_.stop
            step = slice_.step
        else:
            if isinstance(slice_, float):
                start = v2i(slice_)
            else:
                start = self._get_positive_index(slice_)
            stop = start + 1
            step = None

        if isinstance(step, float):
            step = int(round(step / self.scale))
        if isinstance(start, float):
            try:
                start = v2i(start)
            except ValueError:
                # The value is below the axis limits
                # we slice from the start.
                start = None
        if isinstance(stop, float):
            try:
                stop = v2i(stop)
            except ValueError:
                # The value is above the axes limits
                # we slice up to the end.
                stop = None

        if step == 0:
            raise ValueError("slice step cannot be zero")

        my_slice = slice(start, stop, step)

        if start is None:
            if step is None or step > 0:
                start = 0
            else:
                start = self.size - 1
        self.offset = i2v(start)
        if step is not None:
            self.scale *= step

        return my_slice

    def _get_name(self):
        name = (self.name if self.name is not t.Undefined else
                ("Unnamed " + ordinal(self.index_in_axes_manager)))
        return name

    def __repr__(self):
        text = '<%s axis, size: %i' % (
            self._get_name(),
            self.size,
        )
        if self.navigate is True:
            text += ", index: %i" % self.index
        text += ">"
        return text

    def __str__(self):
        return self._get_name() + " axis"

    def connect(self, f, trait='value'):
        self.on_trait_change(f, trait)

    def disconnect(self, f, trait='value'):
        self.on_trait_change(f, trait, remove=True)

    def update_index_bounds(self):
        self.high_index = self.size - 1

    def update_axis(self):
        self.axis = generate_axis(self.offset, self.scale, self.size)
        if len(self.axis) != 0:
            self.low_value, self.high_value = (self.axis.min(),
                                               self.axis.max())

    def _update_slice(self, value):
        if value is False:
            self.slice = slice(None)
        else:
            self.slice = None

    def get_axis_dictionary(self):
        adict = {
            'name': self.name,
            'scale': self.scale,
            'offset': self.offset,
            'size': self.size,
            'units': self.units,
            'index_in_array': self.index_in_array,
            'navigate': self.navigate
        }
        return adict

    def copy(self):
        return DataAxis(**self.get_axis_dictionary())

    def update_value(self):
        self.value = self.axis[self.index]

    def value2index(self, value, rounding=round):
        """Return the closest index to the given value if between the limit.

        Parameters
        ----------
        value : float

        Returns
        -------
        int

        Raises
        ------
        ValueError if value is out of the axis limits.

        """
        if value is None:
            return None
        else:
            index = int(rounding((value - self.offset) / self.scale))
            if self.size > index >= 0:
                return index
            else:
                raise ValueError("The value is out of the axis limits")

    def index2value(self, index):
        return self.axis[index]

    def set_index_from_value(self, value):
        self.index = self.value2index(value)
        # If the value is above the limits we must correct the value
        if self.continuous_value is False:
            self.value = self.index2value(self.index)

    def calibrate(self, value_tuple, index_tuple, modify_calibration=True):
        scale = ((value_tuple[1] - value_tuple[0]) /
                 (index_tuple[1] - index_tuple[0]))
        offset = value_tuple[0] - scale * index_tuple[0]
        if modify_calibration is True:
            self.offset = offset
            self.scale = scale
        else:
            return offset, scale
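
A short usage sketch of the calibration arithmetic above (an editor's illustration, not part of the source: it assumes the constructor's first parameter is `size`, as `self.size = size` suggests, and that `generate_axis` yields `offset + scale * arange(size)`):

axis = DataAxis(size=1024, name='Energy', units='eV', scale=0.5, offset=100.0)
i = axis.value2index(150.0)   # int(round((150.0 - 100.0) / 0.5)) == 100
v = axis.index2value(i)       # 100.0 + 0.5 * 100 == 150.0
axis.calibrate(value_tuple=(0.0, 10.0), index_tuple=(0, 100))
# scale  = (10.0 - 0.0) / (100 - 0) = 0.1
# offset = 0.0 - 0.1 * 0 = 0.0
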
Example #9
class HistoryEditor(EditorFactory):
    """
    History slider running between 0 and 1 by default
    """
    label = 'history'
    var = tr.Str('t')
    min_value = tr.Float(0)
    max_value = tr.Float(1)
    step_value = tr.Float(0.01)
    min_var = tr.Str('')
    max_var = tr.Str('')
    step_var = tr.Str('')

    tooltip = tr.Property(depends_on='var, max_var')
    @tr.cached_property
    def _get_tooltip(self):
        return 'history slider 0 -> %s -> %s' % (self.var, self.max_var)

    t_min = tr.Property
    def _get_t_min(self):
        if self.min_var == '':
            t_min = self.min_value
        else:
            t_min = getattr(self.model, str(self.min_var))
        return t_min

    t_max = tr.Property
    def _get_t_max(self):
        if self.max_var == '':
            t_max = self.max_value
        else:
            t_max = getattr(self.model, str(self.max_var))
        return t_max

    step = tr.Property
    def _get_step(self):
        if self.step_var == '':
            step = self.step_value
        else:
            step = getattr(self.model, str(self.step_var))
        return step

    def render(self):
        history_bar_widgets = []
        eta = (getattr(self.model, str(self.var)) - self.t_min) / (self.t_max - self.t_min)
        history_slider = ipw.FloatSlider(
            value=eta,
            min=0,
            max=1,
            step=0.01,
            tooltip=self.tooltip,
            continuous_update=False,
            description=self.label,
            disabled=self.disabled,
            # readout=self.readout,
            # readout_format=self.readout_format
            layout=ipw.Layout(display='flex', width="100%")
        )

        def change_time_var(event):
            eta = event['new']
            # with bu.print_output:
            #     print('slider on',self.model)
            t = self.t_min + (self.t_max - self.t_min) * eta
            setattr(self.model, self.var, t)
            app_window = self.controller.app_window
            app_window.update_plot(self.model)

        history_slider.observe(change_time_var, 'value')

        # if self.min_var != '':
        #     def change_t_min(event):
        #         t_min = event.new
        #         history_slider.min = t_min
        #     self.model.observe(change_t_min, self.min_var)
        #
        # if self.max_var != '':
        #     def change_t_max(event):
        #         t_max = event.new
        #         history_slider.max = t_max
        #     self.model.observe(change_t_max, self.max_var)

        history_bar_widgets.append(history_slider)
        history_box = ipw.HBox(history_bar_widgets,
                                layout=ipw.Layout(padding='0px'))
        history_box.layout.align_items = 'center'
        return history_box
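
For reference, the mapping between the model's time variable and the unit-interval slider position used by `render` and `change_time_var` above (a standalone sketch; the values are hypothetical):

t_min, t_max, t = 0.0, 10.0, 2.5
eta = (t - t_min) / (t_max - t_min)      # forward map in render(): 0.25
t_back = t_min + (t_max - t_min) * eta   # inverse map in change_time_var(): 2.5
assert t_back == t
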
Example #10
class EnsembleTrainer(t.HasStrictTraits):
    def __init__(self, config=None, **kwargs):
        # Avoid a shared mutable default argument.
        config = config or {}
        trainer_template = Trainer(**config)
        super().__init__(trainer_template=trainer_template,
                         config=config,
                         **kwargs)

    config: dict = t.Dict()

    trainer_template: Trainer = t.Instance(Trainer)
    trainers: ty.List[Trainer] = t.List(t.Instance(Trainer))

    n_folds = t.Int(5)

    dl_test: DataLoader = t.DelegatesTo("trainer_template")
    data_spec: dict = t.DelegatesTo("trainer_template")
    cuda: bool = t.DelegatesTo("trainer_template")
    device: str = t.DelegatesTo("trainer_template")
    loss_func: str = t.DelegatesTo("trainer_template")
    batch_size: int = t.DelegatesTo("trainer_template")
    win_len: int = t.DelegatesTo("trainer_template")
    has_null_class: bool = t.DelegatesTo("trainer_template")
    predict_null_class: bool = t.DelegatesTo("trainer_template")
    name: str = t.Str()

    def _name_default(self):
        import time

        modelstr = "Ensemble"
        timestr = time.strftime("%Y%m%d-%H%M%S")
        return f"{modelstr}_{timestr}"

    X_folds = t.Tuple(transient=True)
    ys_folds = t.Tuple(transient=True)

    def _trainers_default(self):
        # Temp trainer for grabbing datasets, etc
        tt = self.trainer_template
        tt.init_data()

        # Combine official train & val sets
        X = torch.cat(
            [tt.dl_train.dataset.tensors[0], tt.dl_val.dataset.tensors[0]])
        ys = [
            torch.cat([yt, yv]) for yt, yv in zip(
                tt.dl_train.dataset.tensors[1:], tt.dl_val.dataset.tensors[1:])
        ]
        # make folds
        fold_len = int(np.ceil(len(X) / self.n_folds))
        self.X_folds = torch.split(X, fold_len)
        self.ys_folds = tuple(torch.split(y, fold_len) for y in ys)

        trainers = []
        for i_val_fold in range(self.n_folds):
            trainer = Trainer(
                validation_fold=i_val_fold,
                name=f"{self.name}/{i_val_fold}",
                **self.config,
            )

            trainer.dl_test = tt.dl_test

            trainers.append(trainer)

        return trainers

    model: models.BaseNet = t.Instance(torch.nn.Module, transient=True)

    def _model_default(self):
        model = models.FilterNetEnsemble()
        model.set_models([trainer.model for trainer in self.trainers])
        return model

    model_path: str = t.Str()

    def _model_path_default(self):
        return f"saved_models/{self.name}/"

    def init_data(self):
        # Initiate loading of datasets, model
        pass
        # for trainer in self.trainers:
        #     trainer.init_data()

    def init_train(self):
        pass
        # for trainer in self.trainers:
        #     trainer.init_train()

    def train(self, max_epochs=50):
        """ A pretty standard training loop, constrained to stop in `max_epochs` but may stop early if our
        custom stopping metric does not improve for `self.patience` epochs. Always checkpoints
        when a new best stopping_metric is achieved. An alternative to using
        ray.tune for training."""

        for trainer in self.trainers:
            # Add data to trainer

            X_train = torch.cat([
                arr for i, arr in enumerate(self.X_folds)
                if i != trainer.validation_fold
            ])
            ys_train = [
                torch.cat([
                    arr for i, arr in enumerate(y)
                    if i != trainer.validation_fold
                ]) for y in self.ys_folds
            ]

            X_val = torch.cat([
                arr for i, arr in enumerate(self.X_folds)
                if i == trainer.validation_fold
            ])
            ys_val = [
                torch.cat([
                    arr for i, arr in enumerate(y)
                    if i == trainer.validation_fold
                ]) for y in self.ys_folds
            ]

            trainer.dl_train = DataLoader(
                TensorDataset(torch.Tensor(X_train), *ys_train),
                batch_size=trainer.batch_size,
                shuffle=True,
            )
            trainer.data_spec = self.trainer_template.data_spec
            trainer.epoch_iters = self.trainer_template.epoch_iters
            trainer.dl_val = DataLoader(
                TensorDataset(torch.Tensor(X_val), *ys_val),
                batch_size=trainer.batch_size,
                shuffle=False,
            )

            # Now clear local vars to save RAM
            X_train = ys_train = X_val = ys_val = None

            trainer.init_data()
            trainer.init_train()
            trainer.train(max_epochs=max_epochs)

            # Clear trainer train and val datasets to save ram
            trainer.dl_train = t.Undefined
            trainer.dl_val = t.Undefined

            print(f"RESTORING TO best model")
            trainer._restore()
            trainer._save()

            trainer.print_train_summary()

            em = EvalModel(trainer=trainer)

            em.run_test_set()
            em.calc_metrics()
            em.calc_ward_metrics()
            print(em.classification_report_df.to_string(float_format="%.3f"))
            em._save()

    def print_train_summary(self):
        for trainer in self.trainers:
            trainer.print_train_summary()

    def _save(self, checkpoint_dir=None, save_model=True, save_trainer=True):
        """ Saves/checkpoints model state and training state to disk. """
        if checkpoint_dir is None:
            checkpoint_dir = self.model_path
        else:
            self.model_path = checkpoint_dir

        os.makedirs(checkpoint_dir, exist_ok=True)

        # save model params
        model_path = os.path.join(checkpoint_dir, "model.pth")
        trainer_path = os.path.join(checkpoint_dir, "trainer.pth")

        if save_model:
            torch.save(self.model.state_dict(), model_path)
        if save_trainer:
            with open(trainer_path, "wb") as f:
                pickle.dump(self, f)

        return checkpoint_dir

    def _restore(self, checkpoint_dir=None):
        """ Restores model state and training state from disk. """

        if checkpoint_dir is None:
            checkpoint_dir = self.model_path

        model_path = os.path.join(checkpoint_dir, "model.pth")
        trainer_path = os.path.join(checkpoint_dir, "trainer.pth")

        # Reconstitute old trainer and copy state to this trainer.
        with open(trainer_path, "rb") as f:
            other_trainer = pickle.load(f)

        self.__setstate__(other_trainer.__getstate__())

        # Load sub-models
        for trainer in self.trainers:
            trainer._restore()

        # Load model (after loading state in case we need to re-initialize model from config)
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device))
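
A minimal standalone sketch of the fold bookkeeping shared by `_trainers_default` and `train` above (an editor's illustration; the data tensor is a stand-in):

import numpy as np
import torch

X = torch.arange(10).float()        # stand-in for the combined train+val data
n_folds = 5
fold_len = int(np.ceil(len(X) / n_folds))
X_folds = torch.split(X, fold_len)  # tuple of 5 tensors, each of length 2

i_val_fold = 0                      # each Trainer holds out one fold
X_val = torch.cat([a for i, a in enumerate(X_folds) if i == i_val_fold])
X_train = torch.cat([a for i, a in enumerate(X_folds) if i != i_val_fold])
assert len(X_train) + len(X_val) == len(X)
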
Example #11
class EvalModel(t.HasStrictTraits):
    trainer: Trainer = t.Any()
    model: mo.BaseNet = t.DelegatesTo("trainer")
    dl_test: DataLoader = t.DelegatesTo("trainer")
    data_spec: dict = t.DelegatesTo("trainer")
    cuda: bool = t.DelegatesTo("trainer")
    device: str = t.DelegatesTo("trainer")
    loss_func: str = t.DelegatesTo("trainer")
    model_path: str = t.DelegatesTo("trainer")
    has_null_class: bool = t.DelegatesTo("trainer")
    predict_null_class: bool = t.DelegatesTo("trainer")

    # 'prediction' mode employs overlap and reconstructs the signal
    #   as a contiguous timeseries w/ optional windowing.
    #   It aims for best accuracy/f1 by using overlap, and will
    #   typically outperform 'training' mode.
    # 'training' mode does not average repeated points and does
    #   not window; it should produce acc/loss/f1 similar to
    #   those seen during training.
    run_mode: str = t.Enum(["prediction", "training"])
    window: str = t.Enum(["hanning", "boxcar"])
    eval_batch_size: int = t.Int(100)

    target_names: ty.List[str] = t.ListStr()

    def _target_names_default(self):
        target_names = self.data_spec["output_spec"][0]["classes"]

        if self.has_null_class:
            assert target_names[0] in ("", "Null")

            if not self.predict_null_class:
                target_names = target_names[1:]

        return target_names

    def _run_model_on_batch(self, data, targets):
        targets = torch.stack(targets)

        if self.cuda:
            data, targets = data.cuda(), targets.cuda()

        output = self.model(data)

        _targets = self.model.transform_targets(targets, one_hot=False)
        if self.loss_func == "cross_entropy":
            _losses = [F.cross_entropy(o, t) for o, t in zip(output, _targets)]
            loss = sum(_losses)
        elif self.loss_func == "binary_cross_entropy":
            _targets_onehot = self.model.transform_targets(targets,
                                                           one_hot=True)
            _losses = [
                F.binary_cross_entropy_with_logits(o, t)
                for o, t in zip(output, _targets_onehot)
            ]
            loss = sum(_losses)
        else:
            raise NotImplementedError(self.loss_func)

        # Assume only 1 output:

        return loss, output[0], _targets[0], _losses[0]

    def run_test_set(self, dl=None):
        """ Runs `self.model` on `self.dl_test` (or a provided dl) and stores results for subsequent evaluation. """
        if dl is None:
            dl = self.dl_test

        if self.cuda:
            self.model.cuda()
        self.model.eval()
        if self.eval_batch_size:
            dl = DataLoader(dl.dataset,
                            batch_size=self.eval_batch_size,
                            shuffle=False)
        #
        #     # Xc, yc = data.get_x_y_contig('test')
        X, *ys = dl.dataset.tensors
        # X: [N, input_chans, win_len]
        step = int(X.shape[2] / 2)
        assert torch.equal(X[0, :, step], X[1, :, 0])

        losses = []
        outputsraw = []
        outputs = []
        targets = []

        with Timer("run", log_output=False) as tr:
            with Timer("infer", log_output=False) as ti:
                for batch_idx, (data, *target) in enumerate(dl):
                    (
                        batch_loss,
                        batch_output,
                        batch_targets,
                        train_losses,
                    ) = self._run_model_on_batch(data, target)

                    losses.append(batch_loss.detach().cpu().item())
                    outputsraw.append(batch_output.detach().cpu().data.numpy())
                    outputs.append(
                        torch.argmax(batch_output, 1,
                                     False).detach().cpu().data.numpy())
                    targets.append(batch_targets.detach().cpu().data.numpy())
            self.infer_time_s_cpu = ti.interval_cpu
            self.infer_time_s_wall = ti.interval_wall

            self.loss = np.mean(losses)
            targets = np.concatenate(targets, axis=0)  # [N, out_win_len]
            outputsraw = np.concatenate(
                outputsraw, axis=0)  # [N, n_out_classes, out_win_len]
            outputs = np.concatenate(outputs,
                                     axis=0)  # [N, n_out_classes, out_win_len]

            # win_len = toutputsraw[0].shape[-1]
            if (self.model.output_type == "many_to_one_takelast"
                    or self.run_mode == "training"):
                self.targets = np.concatenate(targets, axis=-1)  # [N,]
                self.outputsraw = np.concatenate(
                    outputsraw, axis=-1)  # [n_out_classes, N,]
                self.outputs = np.concatenate(outputs, axis=-1)  # [N,]

            elif self.run_mode == "prediction":
                n_segments, n_classes, out_win_len = outputsraw.shape

                output_step = int(out_win_len / 2)

                if self.window == "hanning":
                    EPS = 0.001  # prevents divide-by-zero
                    arr_window = (1 - EPS) * np.hanning(out_win_len) + EPS
                elif self.window == "boxcar":
                    arr_window = np.ones((out_win_len, ))
                else:
                    raise ValueError()

                # Allocate space for merged predictions
                if self.has_null_class and not self.predict_null_class:
                    outputsraw2 = np.zeros(
                        (n_segments + 1, n_classes - 1, output_step, 2))
                    window2 = np.zeros(
                        (n_segments + 1, n_classes - 1, output_step,
                         2))  # [N+1, out_win_len/2, 2]
                    # Drop in outputs/window vals in the two layers
                    outputsraw = outputsraw[:, 1:, :]
                else:
                    outputsraw2 = np.zeros(
                        (n_segments + 1, n_classes, output_step, 2))
                    window2 = np.zeros((n_segments + 1, n_classes, output_step,
                                        2))  # [N+1, out_win_len/2, 2]

                # Drop in outputs/window vals in the two layers
                outputsraw2[:-1, :, :, 0] = outputsraw[:, :, :output_step]
                outputsraw2[1:, :, :,
                            1] = outputsraw[:, :, output_step:output_step * 2]
                window2[:-1, :, :, 0] = arr_window[:output_step]
                window2[1:, :, :, 1] = arr_window[output_step:output_step * 2]

                merged_outputsraw = (outputsraw2 * window2).sum(
                    axis=3) / (window2).sum(axis=3)
                softmaxed_merged_outputsraw = softmax(merged_outputsraw,
                                                      axis=1)
                merged_outputs = np.argmax(softmaxed_merged_outputsraw, 1)

                self.outputsraw = np.concatenate(merged_outputsraw, axis=-1)
                self.outputs = np.concatenate(merged_outputs, axis=-1)
                self.targets = np.concatenate(
                    np.concatenate(
                        [
                            targets[:, :output_step],
                            targets[[-1], output_step:output_step * 2],
                        ],
                        axis=0,
                    ),
                    axis=-1,
                )
            else:
                raise ValueError()

        if self.has_null_class and not self.predict_null_class:
            not_null_mask = self.targets > 0
            self.outputsraw = self.outputsraw[..., not_null_mask]
            self.outputs = self.outputs[not_null_mask]
            self.targets = self.targets[not_null_mask]
            self.targets -= 1

        self.n_samples_in = np.prod(dl.dataset.tensors[1].shape)
        self.n_samples_out = len(self.outputs)
        self.infer_samples_per_s = self.n_samples_in / self.infer_time_s_wall
        self.run_time_s_cpu = tr.interval_cpu
        self.run_time_s_wall = tr.interval_wall

    loss: float = t.Float()
    targets: np.ndarray = t.Array()
    outputsraw: np.ndarray = t.Array()
    outputs: np.ndarray = t.Array()
    n_samples_in: int = t.Int()
    n_samples_out: int = t.Int()
    infer_samples_per_s: float = t.Float()

    infer_time_s_cpu: float = t.Float()
    infer_time_s_wall: float = t.Float()
    run_time_s_cpu: float = t.Float()
    run_time_s_wall: float = t.Float()

    extra: dict = t.Dict({})

    acc: float = t.Float()
    f1: float = t.Float()
    f1_mean: float = t.Float()
    event_f1: float = t.Float()
    classification_report_txt: str = t.Str()
    classification_report_dict: dict = t.Dict()
    classification_report_df: pd.DataFrame = t.Property(
        t.Instance(pd.DataFrame))
    confusion_matrix: np.ndarray = t.Array()

    nonull_acc: float = t.Float()
    nonull_f1: float = t.Float()
    nonull_f1_mean: float = t.Float()
    nonull_classification_report_txt: str = t.Str()
    nonull_classification_report_dict: dict = t.Dict()
    nonull_classification_report_df: pd.DataFrame = t.Property(
        t.Instance(pd.DataFrame))
    nonull_confusion_matrix: np.ndarray = t.Array()

    def calc_metrics(self):

        self.acc = sklearn.metrics.accuracy_score(self.targets, self.outputs)
        self.f1 = sklearn.metrics.f1_score(self.targets,
                                           self.outputs,
                                           average="weighted")
        self.f1_mean = sklearn.metrics.f1_score(self.targets,
                                                self.outputs,
                                                average="macro")

        self.classification_report_txt = sklearn.metrics.classification_report(
            self.targets,
            self.outputs,
            digits=3,
            labels=np.arange(len(self.target_names)),
            target_names=self.target_names,
        )
        self.classification_report_dict = sklearn.metrics.classification_report(
            self.targets,
            self.outputs,
            digits=3,
            output_dict=True,
            labels=np.arange(len(self.target_names)),
            target_names=self.target_names,
        )
        self.confusion_matrix = sklearn.metrics.confusion_matrix(
            self.targets, self.outputs)

        # Now, ignoring the null/none class:
        if self.has_null_class and self.predict_null_class:
            # assume the null class comes first
            nonull_mask = self.targets > 0
            nonull_targets = self.targets[nonull_mask]
            # nonull_outputs = self.outputs[nonull_mask]
            nonull_outputs = self.outputsraw[1:, :].argmax(
                axis=0)[nonull_mask] + 1

            self.nonull_acc = sklearn.metrics.accuracy_score(
                nonull_targets, nonull_outputs)
            self.nonull_f1 = sklearn.metrics.f1_score(nonull_targets,
                                                      nonull_outputs,
                                                      average="weighted")
            self.nonull_f1_mean = sklearn.metrics.f1_score(nonull_targets,
                                                           nonull_outputs,
                                                           average="macro")
            self.nonull_classification_report_txt = sklearn.metrics.classification_report(
                nonull_targets,
                nonull_outputs,
                digits=3,
                labels=np.arange(len(self.target_names)),
                target_names=self.target_names,
            )
            self.nonull_classification_report_dict = sklearn.metrics.classification_report(
                nonull_targets,
                nonull_outputs,
                digits=3,
                output_dict=True,
                labels=np.arange(len(self.target_names)),
                target_names=self.target_names,
            )
            self.nonull_confusion_matrix = sklearn.metrics.confusion_matrix(
                nonull_targets, nonull_outputs)
        else:
            self.nonull_acc = self.acc
            self.nonull_f1 = self.f1
            self.nonull_f1_mean = self.f1_mean
            self.nonull_classification_report_txt = self.classification_report_txt
            self.nonull_classification_report_dict = self.classification_report_dict
            self.nonull_confusion_matrix = self.confusion_matrix

    ward_metrics: WardMetrics = t.Instance(WardMetrics)

    def calc_ward_metrics(self):
        """ Do event-wise metrics, using the `wardmetrics` package which implements metrics from:

         [1]    J. A. Ward, P. Lukowicz, and H. W. Gellersen, “Performance metrics for activity recognition,”
                    ACM Trans. Intell. Syst. Technol., vol. 2, no. 1, pp. 1–23, Jan. 2011.
        """

        import wardmetrics

        # Must be in prediction mode -- otherwise, data is not contiguous, ward metrics will be bogus
        assert self.run_mode == "prediction"

        targets = self.targets
        predictions = self.outputs

        wmetrics = WardMetrics()

        targets_events = wardmetrics.frame_results_to_events(targets)
        preds_events = wardmetrics.frame_results_to_events(predictions)

        for i, class_name in enumerate(self.target_names):
            class_wmetrics = ClassWardMetrics()

            t = targets_events.get(str(i), [])
            p = preds_events.get(str(i), [])
            # class_wmetrics['t'] = t
            # class_wmetrics['p'] = p

            try:
                assert len(t) and len(p)
                (
                    twoset_results,
                    segments_with_scores,
                    segment_counts,
                    normed_segment_counts,
                ) = wardmetrics.eval_segments(t, p)
                class_wmetrics.segment_twoset_results = twoset_results

                (
                    gt_event_scores,
                    det_event_scores,
                    detailed_scores,
                    standard_scores,
                ) = wardmetrics.eval_events(t, p)
                class_wmetrics.event_detailed_scores = detailed_scores
                class_wmetrics.event_standard_scores = standard_scores
            except (AssertionError, ZeroDivisionError) as e:
                class_wmetrics.segment_twoset_results = {}
                class_wmetrics.event_detailed_scores = {}
                class_wmetrics.event_standard_scores = {}
                # print("Empty Results or targets for a class.")
                # raise ValueError("Empty Results or targets for a class.")

            wmetrics.class_ward_metrics.append(class_wmetrics)

        tt = []
        pp = []
        for i, class_name in enumerate(self.target_names):
            # skip null class for combined eventing:
            if class_name in ("", "Null"):
                continue

            if len(tt) or len(pp):
                offset = np.max(tt + pp) + 2
            else:
                offset = 0

            t = targets_events.get(str(i), [])
            p = preds_events.get(str(i), [])

            tt += [(a + offset, b + offset) for (a, b) in t]
            pp += [(a + offset, b + offset) for (a, b) in p]

        t = tt
        p = pp

        class_wmetrics = ClassWardMetrics()
        assert len(t) and len(p)
        (
            twoset_results,
            segments_with_scores,
            segment_counts,
            normed_segment_counts,
        ) = wardmetrics.eval_segments(t, p)
        class_wmetrics.segment_twoset_results = twoset_results

        (
            gt_event_scores,
            det_event_scores,
            detailed_scores,
            standard_scores,
        ) = wardmetrics.eval_events(t, p)
        class_wmetrics.event_detailed_scores = detailed_scores
        class_wmetrics.event_standard_scores = standard_scores

        # Reformat as dataframe for easier calculations
        df = pd.DataFrame(
            [cm.event_standard_scores for cm in wmetrics.class_ward_metrics],
            index=self.target_names,
        )
        df.loc["all_nonull"] = class_wmetrics.event_standard_scores

        # Calculate F1's to summarize recall/precision for each class
        df["f1"] = (2 * (df["precision"] * df["recall"]) /
                    (df["precision"] + df["recall"]))
        df["f1 (weighted)"] = (
            2 * (df["precision (weighted)"] * df["recall (weighted)"]) /
            (df["precision (weighted)"] + df["recall (weighted)"]))

        # Load dataframes into dictionary output
        wmetrics.df_event_scores = df
        wmetrics.df_event_detailed_scores = pd.DataFrame(
            [cm.event_detailed_scores for cm in wmetrics.class_ward_metrics],
            index=self.target_names,
        )
        wmetrics.df_segment_2set_results = pd.DataFrame(
            [cm.segment_twoset_results for cm in wmetrics.class_ward_metrics],
            index=self.target_names,
        )
        wmetrics.overall_ward_metrics = class_wmetrics

        self.ward_metrics = wmetrics
        self.event_f1 = self.ward_metrics.df_event_scores.loc["all_nonull",
                                                              "f1"]

    def _get_classification_report_df(self):
        df = pd.DataFrame(self.classification_report_dict).T

        # Include Ward-metrics-derived "Event F1 (unweighted by length)"
        if self.ward_metrics:
            df["event_f1"] = self.ward_metrics.df_event_scores["f1"]
        else:
            df["event_f1"] = np.nan

        # Calculate various summary averages
        df.loc["macro avg", "event_f1"] = df["event_f1"].iloc[:-3].mean()
        df.loc["weighted avg", "event_f1"] = (
            df["event_f1"].iloc[:-3] *
            df["support"].iloc[:-3]).sum() / df["support"].iloc[:-3].sum()

        df["support"] = df["support"].astype(int)

        return df

    def _get_nonull_classification_report_df(self):
        target_names = self.target_names
        if target_names[0] not in ("", "Null"):
            return None

        df = pd.DataFrame(self.nonull_classification_report_dict).T

        df["support"] = df["support"].astype(int)

        return df

    def _save(self, checkpoint_dir=None):
        """ Saves/checkpoints model state and training state to disk. """
        if checkpoint_dir is None:
            checkpoint_dir = self.model_path

        os.makedirs(checkpoint_dir, exist_ok=True)

        # save model params
        evalmodel_path = os.path.join(checkpoint_dir, "evalmodel.pth")

        with open(evalmodel_path, "wb") as f:
            pickle.dump(self, f)

        return checkpoint_dir

    def _restore(self, checkpoint_dir=None):
        """ Restores model state and training state from disk. """

        if checkpoint_dir is None:
            checkpoint_dir = self.model_path

        evalmodel_path = os.path.join(checkpoint_dir, "evalmodel.pth")

        # Reconstitute old trainer and copy state to this trainer.
        with open(evalmodel_path, "rb") as f:
            other_evalmodel = pickle.load(f)

        self.__setstate__(other_evalmodel.__getstate__())

        self.trainer._restore(checkpoint_dir)
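
The least obvious step in `run_test_set` above is the merge of 50%-overlapping windows. This toy sketch (an editor's illustration: numpy only, one class, two windows) reproduces the same window-weighted averaging that the `outputsraw2`/`window2` stacking performs across all segments at once:

import numpy as np

out_win_len = 8
step = out_win_len // 2
EPS = 0.001
w = (1 - EPS) * np.hanning(out_win_len) + EPS  # strictly positive window

a = np.ones(out_win_len)      # raw outputs from window k
b = 2 * np.ones(out_win_len)  # raw outputs from window k + 1 (50% overlap)

# Weighted average over the overlapping halves; each merged sample lies
# between 1 and 2, weighted toward whichever window it is more central to.
merged = (a[step:] * w[step:] + b[:step] * w[:step]) / (w[step:] + w[:step])
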
Example #12
class Parameter(t.HasTraits):
    """Model parameter

    Attributes
    ----------
    value : float or array
        The value of the parameter for the current location. The value
        for other locations is stored in map.
    bmin, bmax: float
        Lower and upper bounds of the parameter value.
    twin : {None, Parameter}
        If it is not None, the value of the current parameter is
        a function of the given Parameter. The function is by default
        the identity function, but it can be defined by twin_function
    twin_function : function
        Function that, if self.twin is not None, takes self.twin.value
        as its only argument and returns a float or array that is
        returned when getting Parameter.value
    twin_inverse_function : function
        The inverse of twin_function. If it is None then it is not
        possible to set the value of the parameter twin by setting
        the value of the current parameter.
    ext_force_positive : bool
        If True, the parameter value is set to be the absolute value
        of the input value i.e. if we set Parameter.value = -3, the
        value stored is 3 instead. This is useful to bound a value
        to be positive in an optimization without actually using an
        optimizer that supports bounding.
    ext_bounded : bool
        Similar to ext_force_positive, but in this case the bounds are
        defined by bmin and bmax. It is a better idea to use
        an optimizer that supports bounding though.

    Methods
    -------
    as_signal(field = 'values')
        Get a parameter map as a signal object
    plot()
        Plots the value of the Parameter at all locations.
    export(folder=None, name=None, format=None, save_std=False)
        Saves the value of the parameter map to the specified format
    connect, disconnect(function)
        Call the functions connected when the value attribute changes.

    """
    __number_of_elements = 1
    __value = 0
    __free = True
    _bounds = (None, None)
    __twin = None
    _axes_manager = None
    __ext_bounded = False
    __ext_force_positive = False

    # traitsui bugs out trying to make an editor for this, so always specify!
    # (it bugs out because both editors share the object, and Array editors
    # don't like non-sequence objects). TextEditor() works well, so does
    # RangeEditor() as it works with bmin/bmax.
    value = t.Property(t.Either([t.CFloat(0), Array()]))

    units = t.Str('')
    free = t.Property(t.CBool(True))

    bmin = t.Property(NoneFloat(), label="Lower bounds")
    bmax = t.Property(NoneFloat(), label="Upper bounds")

    def __init__(self):
        self._twins = set()
        self.events = Events()
        self.events.value_changed = Event("""
            Event that triggers when the `Parameter.value` changes.

            The event triggers after the internal state of the `Parameter` has
            been updated.

            Arguments
            ---------
            obj : Parameter
                The `Parameter` that the event belongs to
            value : {float | array}
                The new value of the parameter
            """,
                                          arguments=["obj", 'value'])
        self.twin_function = lambda x: x
        self.twin_inverse_function = lambda x: x
        self.std = None
        self.component = None
        self.grad = None
        self.name = ''
        self.units = ''
        self.map = None
        self.model = None
        self._whitelist = {
            '_id_name': None,
            'value': None,
            'std': None,
            'free': None,
            'units': None,
            'map': None,
            '_bounds': None,
            'ext_bounded': None,
            'name': None,
            'ext_force_positive': None,
            'self': ('id', None),
            'twin_function': ('fn', None),
            'twin_inverse_function': ('fn', None),
        }
        self._slicing_whitelist = {'map': 'inav'}

    def _load_dictionary(self, dictionary):
        """Load data from dictionary

        Parameters
        ----------
        dictionary : dict
            A dictionary containing at least the following items:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.load_from_dictionary`
            * any field from _whitelist.keys() *
        Returns
        -------
        id_value : int
            the ID value of the original parameter, to be later used for setting
            up the correct twins

        """
        if dictionary['_id_name'] == self._id_name:
            load_from_dictionary(self, dictionary)
            return dictionary['self']
        else:
            raise ValueError(
                "_id_name of parameter and dictionary do not match,\n"
                "parameter._id_name = %s\n"
                "dictionary['_id_name'] = %s" %
                (self._id_name, dictionary['_id_name']))

    def __repr__(self):
        text = ''
        text += 'Parameter %s' % self.name
        if self.component is not None:
            text += ' of %s' % self.component._get_short_description()
        text = '<' + text + '>'
        return text

    def __len__(self):
        return self._number_of_elements

    def connect(self, f):
        warnings.warn(
            "The method `Parameter.connect()` has been deprecated and will be "
            "removed in HyperSpy 0.10. Please use "
            "`Parameter.events.value_changed.connect()` instead.",
            VisibleDeprecationWarning)
        self.events.value_changed.connect(f, [])

    def disconnect(self, f):
        warnings.warn(
            "The method `Parameter.disconnect()` has been deprecated and will "
            "be removed in HyperSpy 0.10. Please use "
            "`Parameter.events.value_changed.disconnect()` instead.",
            VisibleDeprecationWarning)
        self.events.value_changed.disconnect(f)

    def _get_value(self):
        if self.twin is None:
            return self.__value
        else:
            return self.twin_function(self.twin.value)

    def _set_value(self, value):
        try:
            # Use try/except instead of hasattr("__len__") because a numpy
            # memmap has a __len__ wrapper even for numbers that raises a
            # TypeError when calling. See issue #349.
            if len(value) != self._number_of_elements:
                raise ValueError("The length of the parameter must be ",
                                 self._number_of_elements)
            else:
                if not isinstance(value, tuple):
                    value = tuple(value)
        except TypeError:
            if self._number_of_elements != 1:
                raise ValueError("The length of the parameter must be ",
                                 self._number_of_elements)
        old_value = self.__value

        if self.twin is not None:
            if self.twin_inverse_function is not None:
                self.twin.value = self.twin_inverse_function(value)
            return

        if self.ext_bounded is False:
            self.__value = value
        else:
            if self.ext_force_positive is True:
                value = np.abs(value)
            if self._number_of_elements == 1:
                if self.bmin is not None and value <= self.bmin:
                    self.__value = self.bmin
                elif self.bmax is not None and value >= self.bmax:
                    self.__value = self.bmax
                else:
                    self.__value = value
            else:
                bmin = (self.bmin if self.bmin is not None else -np.inf)
                bmax = (self.bmax if self.bmax is not None else np.inf)
                self.__value = np.clip(value, bmin, bmax)

        if (self._number_of_elements != 1
                and not isinstance(self.__value, tuple)):
            self.__value = tuple(self.__value)
        if old_value != self.__value:
            self.events.value_changed.trigger(value=self.__value, obj=self)
        self.trait_property_changed('value', old_value, self.__value)

    # Fix the parameter when coupled
    def _get_free(self):
        if self.twin is None:
            return self.__free
        else:
            return False

    def _set_free(self, arg):
        old_value = self.__free
        self.__free = arg
        if self.component is not None:
            self.component._update_free_parameters()
        self.trait_property_changed('free', old_value, self.__free)

    def _on_twin_update(self, value, twin=None):
        if (twin is not None and hasattr(twin, 'events')
                and hasattr(twin.events, 'value_changed')):
            with twin.events.value_changed.suppress_callback(
                    self._on_twin_update):
                self.events.value_changed.trigger(value=value, obj=self)
        else:
            self.events.value_changed.trigger(value=value, obj=self)

    def _set_twin(self, arg):
        if arg is None:
            if self.twin is not None:
                # Store the value of the twin in order to set the
                # value of the parameter when it is uncoupled
                twin_value = self.value
                if self in self.twin._twins:
                    self.twin._twins.remove(self)
                    self.twin.events.value_changed.disconnect(
                        self._on_twin_update)

                self.__twin = arg
                self.value = twin_value
        else:
            if self not in arg._twins:
                arg._twins.add(self)
                arg.events.value_changed.connect(self._on_twin_update,
                                                 ["value"])
            self.__twin = arg

        if self.component is not None:
            self.component._update_free_parameters()

    def _get_twin(self):
        return self.__twin

    twin = property(_get_twin, _set_twin)

    def _get_bmin(self):
        if self._number_of_elements == 1:
            return self._bounds[0]
        else:
            return self._bounds[0][0]

    def _set_bmin(self, arg):
        old_value = self.bmin
        if self._number_of_elements == 1:
            self._bounds = (arg, self.bmax)
        else:
            self._bounds = ((arg, self.bmax), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmin', old_value, arg)

    def _get_bmax(self):
        if self._number_of_elements == 1:
            return self._bounds[1]
        else:
            return self._bounds[0][1]

    def _set_bmax(self, arg):
        old_value = self.bmax
        if self._number_of_elements == 1:
            self._bounds = (self.bmin, arg)
        else:
            self._bounds = ((self.bmin, arg), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmax', old_value, arg)

    @property
    def _number_of_elements(self):
        return self.__number_of_elements

    @_number_of_elements.setter
    def _number_of_elements(self, arg):
        # Do nothing if the number of arguments stays the same
        if self.__number_of_elements == arg:
            return
        if arg < 1:
            raise ValueError("Please provide an integer equal to or "
                             "greater than 1")
        self._bounds = ((self.bmin, self.bmax), ) * arg
        self.__number_of_elements = arg

        if arg == 1:
            self._Parameter__value = 0
        else:
            self._Parameter__value = (0, ) * arg
        if self.component is not None:
            self.component.update_number_parameters()

    @property
    def ext_bounded(self):
        return self.__ext_bounded

    @ext_bounded.setter
    def ext_bounded(self, arg):
        if arg is not self.__ext_bounded:
            self.__ext_bounded = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    @property
    def ext_force_positive(self):
        return self.__ext_force_positive

    @ext_force_positive.setter
    def ext_force_positive(self, arg):
        if arg is not self.__ext_force_positive:
            self.__ext_force_positive = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    def store_current_value_in_array(self):
        """Store the value and std attributes.

        See also
        --------
        fetch, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        self.map['values'][indices] = self.value
        self.map['is_set'][indices] = True
        if self.std is not None:
            self.map['std'][indices] = self.std

    def fetch(self):
        """Fetch the stored value and std attributes.


        See Also
        --------
        store_current_value_in_array, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        if self.map['is_set'][indices]:
            self.value = self.map['values'][indices]
            self.std = self.map['std'][indices]

    def assign_current_value_to_all(self, mask=None):
        """Assign the current value attribute to all the  indices

        Parameters
        ----------
        mask: {None, boolean numpy array}
            Set only the indices that are not masked i.e. where
            mask is False.

        See Also
        --------
        store_current_value_in_array, fetch

        """
        if mask is None:
            mask = np.zeros(self.map.shape, dtype='bool')
        self.map['values'][~mask] = self.value
        self.map['is_set'][~mask] = True

    def _create_array(self):
        """Create the map array to store the information in
        multidimensional datasets.

        """
        shape = self._axes_manager._navigation_shape_in_array
        if not shape:
            shape = [
                1,
            ]
        dtype_ = np.dtype([('values', 'float', self._number_of_elements),
                           ('std', 'float', self._number_of_elements),
                           ('is_set', 'bool', 1)])
        if (self.map is None or self.map.shape != shape
                or self.map.dtype != dtype_):
            self.map = np.zeros(shape, dtype_)
            self.map['std'].fill(np.nan)
            # TODO: in the future this class should have access to
            # axes manager and should be able to fetch its own
            # values. Until then, the next line is necessary to avoid
            # errors when self.std is defined and the shape is different
            # from the newly defined arrays
            self.std = None

    def as_signal(self, field='values'):
        """Get a parameter map as a signal object.

        Please note that this method only works when the navigation
        dimension is greater than 0.

        Parameters
        ----------
        field : {'values', 'std', 'is_set'}

        Raises
        ------

        NavigationDimensionError : if the navigation dimension is 0

        """
        from hyperspy.signal import BaseSignal

        s = BaseSignal(data=self.map[field],
                       axes=self._axes_manager._get_navigation_axes_dicts())
        if self.component is not None and \
                self.component.active_is_multidimensional:
            s.data[np.logical_not(self.component._active_array)] = np.nan

        s.metadata.General.title = ("%s parameter" %
                                    self.name if self.component is None else
                                    "%s parameter of %s component" %
                                    (self.name, self.component.name))
        for axis in s.axes_manager._axes:
            axis.navigate = False
        if self._number_of_elements > 1:
            s.axes_manager._append_axis(size=self._number_of_elements,
                                        name=self.name,
                                        navigate=True)
        s._assign_subclass()
        if field == "values":
            # Add the variance if available
            std = self.as_signal(field="std")
            if not np.isnan(std.data).all():
                std.data = std.data**2
                std.metadata.General.title = "Variance"
                s.metadata.set_item("Signal.Noise_properties.variance", std)
        return s

    def plot(self, **kwargs):
        """Plot parameter signal.

        Parameters
        ----------
        **kwargs
            Any extra keyword arguments are passed to the signal plot.

        Example
        -------
        >>> parameter.plot()

        Set the minimum and maximum displayed values

        >>> parameter.plot(vmin=0, vmax=1)
        """
        self.as_signal().plot(**kwargs)

    def export(self, folder=None, name=None, format=None, save_std=False):
        """Save the data to a file.

        All the arguments are optional.

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved. If `None`
            the current folder is used by default.
        name : str or None
            The name of the file. If `None` the Component name followed by
            the Parameter `name` attribute will be used by default. If a
            file with the same name exists the name will be modified by
            appending a number to the file path.
        format : str or None
            The extension of the file format. If `None` the default export
            format from the preferences is used.
        save_std : bool
            If True, the standard deviation will also be saved.

        """
        if format is None:
            format = preferences.General.default_export_format
        if name is None:
            name = self.component.name + '_' + self.name
        filename = incremental_filename(slugify(name) + '.' + format)
        if folder is not None:
            filename = os.path.join(folder, filename)
        self.as_signal().save(filename)
        if save_std is True:
            self.as_signal(field='std').save(append2pathname(filename, '_std'))

    def as_dictionary(self, fullcopy=True):
        """Returns parameter as a dictionary, saving all attributes from
        self._whitelist.keys() For more information see
        :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`

        Parameters
        ----------
        fullcopy : bool (optional, True)
            Copies of objects are stored, not references. If any are found,
            functions will be pickled and signals converted to dictionaries
        Returns
        -------
        dic : dictionary with the following keys:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _twins : list
                a list of ids of the twins of the parameter
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
            * any field from _whitelist.keys() *

        """
        dic = {'_twins': [id(t) for t in self._twins]}
        export_to_dictionary(self, self._whitelist, dic, fullcopy)
        return dic

    def default_traits_view(self):
        # As mentioned above, the default editor for
        # value = t.Property(t.Either([t.CFloat(0), Array()]))
        # gives a ValueError. We therefore implement default_traits_view so
        # that configure/edit_traits will still work straight out of the box.
        # A whitelist controls which traits to include in this view.
        from traitsui.api import RangeEditor, View, Item
        whitelist = ['bmax', 'bmin', 'free', 'name', 'std', 'units', 'value']
        editable_traits = [
            trait for trait in self.editable_traits() if trait in whitelist
        ]
        if 'value' in editable_traits:
            i = editable_traits.index('value')
            v = editable_traits.pop(i)
            editable_traits.insert(
                i,
                Item(v, editor=RangeEditor(low_name='bmin', high_name='bmax')))
        view = View(editable_traits, buttons=['OK', 'Cancel'])
        return view

    def _interactive_slider_bounds(self, index=None):
        """Guesstimates the bounds for the slider. They will probably have to
        be changed later by the user.
        """
        fraction = 10.
        _min, _max, step = None, None, None
        value = self.value if index is None else self.value[index]
        if self.bmin is not None:
            _min = self.bmin
        if self.bmax is not None:
            _max = self.bmax
        if _max is None and _min is not None:
            _max = value + fraction * (value - _min)
        if _min is None and _max is not None:
            _min = value - fraction * (_max - value)
        if _min is None and _max is None:
            if self is self.component._position:
                axis = self._axes_manager.signal_axes[-1]
                _min = axis.axis.min()
                _max = axis.axis.max()
                step = np.abs(axis.scale)
            else:
                _max = value + np.abs(value * fraction)
                _min = value - np.abs(value * fraction)
        if step is None:
            step = (_max - _min) * 0.001
        return {'min': _min, 'max': _max, 'step': step}

    def _interactive_update(self, value=None, index=None):
        """Callback function for the widgets, to update the value
        """
        if value is not None:
            if index is None:
                self.value = value['new']
            else:
                self.value = self.value[:index] + (value['new'],) +\
                    self.value[index + 1:]

    def notebook_interaction(self, display=True):
        """Creates interactive notebook widgets for the parameter, if
        available.
        Requires `ipywidgets` to be installed.
        Parameters
        ----------
        display : bool
            if True (default), attempts to display the parameter widget.
            Otherwise returns the formatted widget object.
        """
        from ipywidgets import VBox
        from traitlets import TraitError as TraitletError
        from IPython.display import display as ip_display
        try:
            if self._number_of_elements == 1:
                container = self._create_notebook_widget()
            else:
                children = [
                    self._create_notebook_widget(index=i)
                    for i in range(self._number_of_elements)
                ]
                container = VBox(children)
            if not display:
                return container
            ip_display(container)
        except TraitletError:
            if display:
                _logger.info('This function is only available when running in'
                             ' a notebook')
            else:
                raise

    def _create_notebook_widget(self, index=None):

        from ipywidgets import (FloatSlider, FloatText, Layout, HBox)

        widget_bounds = self._interactive_slider_bounds(index=index)
        thismin = FloatText(
            value=widget_bounds['min'],
            description='min',
            layout=Layout(flex='0 1 auto', width='auto'),
        )
        thismax = FloatText(
            value=widget_bounds['max'],
            description='max',
            layout=Layout(flex='0 1 auto', width='auto'),
        )
        current_value = self.value if index is None else self.value[index]
        current_name = self.name
        if index is not None:
            current_name += '_{}'.format(index)
        widget = FloatSlider(value=current_value,
                             min=thismin.value,
                             max=thismax.value,
                             step=widget_bounds['step'],
                             description=current_name,
                             layout=Layout(flex='1 1 auto', width='auto'))

        def on_min_change(change):
            if widget.max > change['new']:
                widget.min = change['new']
                widget.step = np.abs(widget.max - widget.min) * 0.001

        def on_max_change(change):
            if widget.min < change['new']:
                widget.max = change['new']
                widget.step = np.abs(widget.max - widget.min) * 0.001

        thismin.observe(on_min_change, names='value')
        thismax.observe(on_max_change, names='value')

        this_observed = functools.partial(self._interactive_update,
                                          index=index)

        widget.observe(this_observed, names='value')
        container = HBox((thismin, widget, thismax))
        return container
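For reference, the linked min/slider/max pattern used in _create_notebook_widget above can be reproduced standalone; this sketch runs in a notebook with only ipywidgets installed (all values are illustrative):

from ipywidgets import FloatSlider, FloatText, HBox, Layout

lo = FloatText(value=0.0, description='min', layout=Layout(width='auto'))
hi = FloatText(value=10.0, description='max', layout=Layout(width='auto'))
slider = FloatSlider(value=5.0, min=lo.value, max=hi.value, step=0.01,
                     description='value', layout=Layout(flex='1 1 auto'))

def on_min_change(change):
    # lower bound may only move while it stays below the current max
    if slider.max > change['new']:
        slider.min = change['new']

def on_max_change(change):
    # upper bound may only move while it stays above the current min
    if slider.min < change['new']:
        slider.max = change['new']

lo.observe(on_min_change, names='value')
hi.observe(on_max_change, names='value')
HBox((lo, slider, hi))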
class HardwareAction(traits.HasTraits):
    """Parent class for all hardware actions. User must make a subclass of this for each
    hardware action and overwrite init, close and callback methods where necessary. Other
    functions can use the parent class implementation directly"""
    callbackTime = traits.Float()
    variables = traits.List()
    variablesReference = {}  # leave empty. This will be set to the experiment control variables before the callback is executed. This is all taken care of by the snake
    hardwareActionName = traits.Str()
    examineVariablesButton = traits.Button()
    enabled = traits.Bool(True)
    callbackTimeVariableDependent = False  # if True, the callbackTime argument is a variable to be parsed in the snake
    callbackTimeString = None  # gets populated if callbackTimeInSequence is a string
    snakeReference = None  # reference to the snake object so that we can call update functions, e.g. for the examineVariablesDict pane
    
    def __init__(self, callbackTimeInSequence, **traitsDict):
        super(HardwareAction,self).__init__(**traitsDict)
        if type(callbackTimeInSequence) is float:
            self.callbackTime = callbackTimeInSequence  # time in the sequence when the callback should be performed, passed during construction
        elif type(callbackTimeInSequence) is str:  # here we check if callback time is a timing edge or a variable
            self.callbackTimeVariableDependent = True
            self.callbackTimeString = callbackTimeInSequence
            logger.info("CallbackTime string detected; attempting to parse string as timing edge or variable")
        else:
            self.callbackTime = callbackTimeInSequence  # time in the sequence when the callback should be performed, passed during construction
        self.awaitingCallback = True # goes to False after it has been called back for the final time in a sequence (usually once)
        self.callbackCounter = 0 # number of times called back this sequence
        self.initialised = False  # set to True when init has run; set to False when close has run
        logger.info( "HardwareAction Super class __init__ completed" )
    
    def _variables_default(self):
        """Uses the variable mappings dictionary defined in the subclass."""
        return list(self.variableMappings.keys())

    def setVariablesDictionary(self, variables):
        """Sets the variables reference to the latest variables dictionary; simply sets the variablesReference attribute."""
        self.variablesReference = variables
    
    def mapVariables(self):
        """returns a dictionary of python variable names used in the callback function
        with their correct values for this run. Raises an error if a variable is missing.
        Could potentially implement default values here"""
        logger.debug( "variables in %s: %s" % (self.hardwareActionName,self.variablesReference))
        try:
            return {self.variableMappings[key]: self.variablesReference[key] for key in self.variableMappings}
        except KeyError as e:
            raise e  # defaults handling TODO
            
    def parseCallbackTime(self):
        """if callback Time is a string we comprehend it as a timing edge name or variable name"""
        if self.callbackTimeString in self.snakeReference.timingEdges:
            self.callbackTime = self.snakeReference.timingEdges[self.callbackTimeString]
        elif self.callbackTimeString in self.snakeReference.variables:
            self.callbackTime = self.snakeReference.variables[self.callbackTimeString]
        else:
            raise KeyError("callbackTime %s was not found in either the timing edges or variables dictionary. Check Spelling? Could not initialise %s object" % (self.callbackTimeString, self.hardwareActionName))

    #####USER SHOULD OVERWRITE THE BELOW FUNCTIONS IN SUBCLASS AS REQUIRED
    def init(self):
        """only called once when the user presses the start button. This should perform
        any hardware specific initialisation. E.g opening sockets / decads connections. Return 
        string is printed to main log terminal"""
        self.initialised=True
        logger.warning("Using default init as no init method has been defined in Hardware Action Subclass")
        return "%s init successful" % self.hardwareActionName

    def close(self):
        """called to close the hardware down when user stops Snake or exits. Should
        safely close the hardware. It should be able to restart again when the 
        init function is called (e.g. user then presses start"""
        logger.warning("Using default close as no close method has been defined in Hardware Action Subclass")
        return "%s closed" % self.hardwareActionName
        
    def callback(self):
        """This is the function that is called every sequence at the callbackTime. 
        IT SHOULD NOT HANG as this is a blocking function call. You need to handle
        threading yourself if the callback function would take a long time to return.
        Return value should be a string to be printed in terminal"""
        logger.debug( "beginning %s callback" % self.hardwareActionName)        
        if not self.initialised:
            return "%s not initialised with init function. Cannot be called back until initialised. Doing nothing" % self.hardwareActionName
        try:#YOUR CALLBACK CODE SHOULD GO IN THIS TRY BLOCK!
            self.finalVariables = self.mapVariables()
            raise NotImplementedError("the callback function needs to be implemented in your subclass")
            return "callback on %s completed" % (self.hardwareActionName)
        except KeyboardInterrupt:
            raise
        except KeyError as e:
            return "Failed to find variable %s in variables %s. Check the variable is defined in experiment control" % (e, list(self.variablesReference.keys()))
        except Exception as e:
            return "Failed to perform callback on %s. Error message: %s" % (self.hardwareActionName, e)
        
        
    def _enabled_changed(self):
        """traitsui handler function (is automatically called when enabled changes during interaction with user interface """
        if self.enabled:
            self.snakeReference.mainLog.addLine("%s was just enabled. Will perform its init method" % self.hardwareActionName,1)
            self.awaitingCallback=False # by setting this to False we prevent the action being performed till the next sequence begins. This is usually desireable            
            self.init()
        elif not self.enabled:
            if self.snakeReference.isRunning:#only print to log if it's disabled while snake is running
                self.snakeReference.mainLog.addLine("%s was just disabled. Will perform its close method" % self.hardwareActionName,1)
            self.close()#close method always performed for safety
            
    def _examineVariablesButton_fired(self):
        """Called when user clicks on book item near hardware action name. This makes a pop up
        which shows all the variables that the hardware action defines. later it might let users
        edit certain parameters"""
        self.snakeReference.updateExamineVariablesDictionary(self)# pass the update this hardwareAction object as the argument
        logger.info("variables = %s" % self.variables)
    
    # traits_view for all hardware actions. Just shows the name and lets the user enable or disable it
    traits_view = traitsui.View(
                    traitsui.HGroup(traitsui.Item("hardwareActionName", show_label=False, style="readonly"),
                                    traitsui.Item("enabled",show_label=False),
                                    traitsui.Item("examineVariablesButton",show_label=False,
                                                  editor=traitsui.ButtonEditor(image = pyface.image_resource.ImageResource( os.path.join(os.getcwd(), 'icons', 'book.png' ))),
                                                   style="custom"), 
                                   )
                               )
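A minimal sketch of the subclassing pattern the docstrings above describe; the class name, variable mapping and callback body below are hypothetical, not from the source:

class ExampleShutterAction(HardwareAction):
    hardwareActionName = traits.Str('exampleShutter')
    variableMappings = {'ShutterOpenTime': 'openTime'}  # control variable -> python name

    def init(self):
        self.initialised = True  # e.g. open sockets / connections here
        return "%s init successful" % self.hardwareActionName

    def close(self):
        return "%s closed" % self.hardwareActionName

    def callback(self):
        if not self.initialised:
            return "%s not initialised. Doing nothing" % self.hardwareActionName
        try:
            self.finalVariables = self.mapVariables()
            # drive the hardware here using self.finalVariables['openTime']
            return "callback on %s completed" % self.hardwareActionName
        except KeyError as e:
            return "Failed to find variable %s" % e

# e.g. scheduled at 0.5 s into the sequence: ExampleShutterAction(0.5)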
Example #14
class HCFT(tr.HasStrictTraits):
    '''High-Cycle Fatigue Tool
    '''
    #=========================================================================
    # Traits definitions
    #=========================================================================
    decimal = tr.Enum(',', '.')
    delimiter = tr.Str(';')
    records_per_second = tr.Float(100)
    take_time_from_time_column = tr.Bool(True)
    file_csv = tr.File
    open_file_csv = tr.Button('Input file')
    skip_first_rows = tr.Range(low=1, high=10**9, mode='spinner')
    columns_headers_list = tr.List([])
    x_axis = tr.Enum(values='columns_headers_list')
    y_axis = tr.Enum(values='columns_headers_list')
    force_column = tr.Enum(values='columns_headers_list')
    time_column = tr.Enum(values='columns_headers_list')
    x_axis_multiplier = tr.Enum(1, -1)
    y_axis_multiplier = tr.Enum(-1, 1)
    npy_folder_path = tr.Str
    file_name = tr.Str
    apply_filters = tr.Bool
    plot_settings_btn = tr.Button
    plot_settings = PlotSettings()
    plot_settings_active = tr.Bool
    normalize_cycles = tr.Bool
    smooth = tr.Bool
    plot_every_nth_point = tr.Range(low=1, high=1000000, mode='spinner')
    old_peak_force_before_cycles = tr.Float
    peak_force_before_cycles = tr.Float
    window_length = tr.Range(low=1, high=10**9 - 1, value=31, mode='spinner')
    polynomial_order = tr.Range(low=1, high=10**9, value=2, mode='spinner')
    activate = tr.Bool(False)
    add_plot = tr.Button
    add_creep_plot = tr.Button(desc='Creep plot of X axis array')
    clear_plot = tr.Button
    parse_csv_to_npy = tr.Button
    generate_filtered_and_creep_npy = tr.Button
    add_columns_average = tr.Button
    force_max = tr.Float(100)
    force_min = tr.Float(40)
    min_cycle_force_range = tr.Float(50)
    cutting_method = tr.Enum(
        'Define min cycle range(force difference)', 'Define Max, Min')
    columns_to_be_averaged = tr.List
    figure = tr.Instance(mpl.figure.Figure)
    log = tr.Str('')
    clear_log = tr.Button

    def _figure_default(self):
        figure = mpl.figure.Figure(facecolor='white')
        figure.set_tight_layout(True)
        return figure

    #=========================================================================
    # File management
    #=========================================================================

    def _open_file_csv_fired(self):
        try:

            self.reset()

            """ Handles the user clicking the 'Open...' button.
            """
            extns = ['*.csv', ]  # seems to handle only one extension...
            wildcard = '|'.join(extns)

            dialog = FileDialog(title='Select text file',
                                action='open', wildcard=wildcard,
                                default_path=self.file_csv)

            result = dialog.open()

            """ Test if the user opened a file to avoid throwing an exception if he 
            doesn't """
            if result == OK:
                self.file_csv = dialog.path
            else:
                return

            """ Filling x_axis and y_axis with values """
            headers_array = np.array(
                pd.read_csv(
                    self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
                    nrows=1, header=None
                )
            )[0]
            for i in range(len(headers_array)):
                headers_array[i] = self.get_valid_file_name(headers_array[i])
            self.columns_headers_list = list(headers_array)

            """ Saving file name and path and creating NPY folder """
            dir_path = os.path.dirname(self.file_csv)
            self.npy_folder_path = os.path.join(dir_path, 'NPY')
            if not os.path.exists(self.npy_folder_path):
                os.makedirs(self.npy_folder_path)

            self.file_name = os.path.splitext(
                os.path.basename(self.file_csv))[0]

        except Exception as e:
            self.deal_with_exception(e)

    def _parse_csv_to_npy_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.parse_csv_to_npy_fired)
        thread.start()

    def parse_csv_to_npy_fired(self):
        try:
            self.print_custom('Parsing csv into npy files...')

            for i in range(len(self.columns_headers_list) -
                           len(self.columns_to_be_averaged)):
                current_column_name = self.columns_headers_list[i]
                column_array = np.array(pd.read_csv(
                    self.file_csv, delimiter=self.delimiter, decimal=self.decimal,
                    skiprows=self.skip_first_rows, usecols=[i]))

                if current_column_name == self.time_column and \
                        not self.take_time_from_time_column:
                    column_array = np.arange(start=0.0,
                                             stop=len(column_array) /
                                             self.records_per_second,
                                             step=1.0 / self.records_per_second)

                np.save(os.path.join(self.npy_folder_path, self.file_name +
                                     '_' + current_column_name + '.npy'),
                        column_array)

            """ Exporting npy arrays of averaged columns """
            for columns_names in self.columns_to_be_averaged:
                temp = np.zeros((1))
                for column_name in columns_names:
                    temp = temp + np.load(os.path.join(self.npy_folder_path,
                                                       self.file_name +
                                                       '_' + column_name +
                                                       '.npy')).flatten()
                avg = temp / len(columns_names)

                avg_file_suffix = self.get_suffix_for_columns_to_be_averaged(
                    columns_names)
                np.save(os.path.join(self.npy_folder_path, self.file_name +
                                     '_' + avg_file_suffix + '.npy'), avg)

            self.print_custom('Finished parsing csv into npy files.')
        except Exception as e:
            self.deal_with_exception(e)

    def get_suffix_for_columns_to_be_averaged(self, columns_names):
        suffix_for_saved_file_name = 'avg_' + '_'.join(columns_names)
        return suffix_for_saved_file_name

    def get_valid_file_name(self, original_file_name):
        valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
        new_valid_file_name = ''.join(
            c for c in original_file_name if c in valid_chars)
        return new_valid_file_name
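    # Example (illustrative): get_valid_file_name('Force [kN]; ch.1')
    # returns 'Force kN ch.1' -- only letters, digits, and "-_.() " survive.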

    def _clear_plot_fired(self):
        self.figure.clear()
        self.data_changed = True

    def _add_columns_average_fired(self):
        try:
            columns_average = ColumnsAverage()
            for name in self.columns_headers_list:
                columns_average.columns.append(Column(column_name=name))

            # kind='modal' pauses the implementation until the window is closed
            columns_average.configure_traits(kind='modal')

            columns_to_be_averaged_temp = []
            for i in columns_average.columns:
                if i.selected:
                    columns_to_be_averaged_temp.append(i.column_name)

            if columns_to_be_averaged_temp:  # If it's not empty
                self.columns_to_be_averaged.append(columns_to_be_averaged_temp)

                avg_file_suffix = self.get_suffix_for_columns_to_be_averaged(
                    columns_to_be_averaged_temp)
                self.columns_headers_list.append(avg_file_suffix)
        except Exception as e:
            self.deal_with_exception(e)

    def _generate_filtered_and_creep_npy_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.generate_filtered_and_creep_npy_fired)
        thread.start()

    def generate_filtered_and_creep_npy_fired(self):
        try:
            if not self.npy_files_exist(os.path.join(
                    self.npy_folder_path, self.file_name + '_' + self.force_column
                    + '.npy')):
                return

            self.print_custom('Generating filtered and creep files...')

            # 1- Export filtered force
            force = np.load(os.path.join(self.npy_folder_path,
                                         self.file_name + '_' + self.force_column
                                         + '.npy')).flatten()
            peak_force_before_cycles_index = np.where(
                abs((force)) > abs(self.peak_force_before_cycles))[0][0]
            force_ascending = force[0:peak_force_before_cycles_index]
            force_rest = force[peak_force_before_cycles_index:]

            force_max_indices, force_min_indices = self.get_array_max_and_min_indices(
                force_rest)

            force_max_min_indices = np.concatenate(
                (force_min_indices, force_max_indices))
            force_max_min_indices.sort()

            force_rest_filtered = force_rest[force_max_min_indices]
            force_filtered = np.concatenate(
                (force_ascending, force_rest_filtered))
            np.save(os.path.join(self.npy_folder_path, self.file_name +
                                 '_' + self.force_column + '_filtered.npy'),
                    force_filtered)

            # 2- Export filtered displacements
            for i in range(0, len(self.columns_headers_list)):
                if self.columns_headers_list[i] != self.force_column and \
                        self.columns_headers_list[i] != self.time_column:

                    disp = np.load(os.path.join(self.npy_folder_path, self.file_name
                                                + '_' +
                                                self.columns_headers_list[i]
                                                + '.npy')).flatten()
                    disp_ascending = disp[0:peak_force_before_cycles_index]
                    disp_rest = disp[peak_force_before_cycles_index:]

                    if self.activate:
                        disp_ascending = savgol_filter(
                            disp_ascending, window_length=self.window_length,
                            polyorder=self.polynomial_order)

                    disp_rest_filtered = disp_rest[force_max_min_indices]
                    filtered_disp = np.concatenate(
                        (disp_ascending, disp_rest_filtered))
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_'
                                         + self.columns_headers_list[i] +
                                         '_filtered.npy'), filtered_disp)

            # 3- Export creep for displacements
            # Cutting unwanted max min values to get correct full cycles and remove
            # false min/max values caused by noise
            if self.cutting_method == "Define Max, Min":
                force_max_indices_cutted, force_min_indices_cutted = \
                    self.cut_indices_of_min_max_range(force_rest,
                                                      force_max_indices,
                                                      force_min_indices,
                                                      self.force_max,
                                                      self.force_min)
            elif self.cutting_method == "Define min cycle range(force difference)":
                force_max_indices_cutted, force_min_indices_cutted = \
                    self.cut_indices_of_defined_range(force_rest,
                                                      force_max_indices,
                                                      force_min_indices,
                                                      self.min_cycle_force_range)

            self.print_custom("Cycles number= ", len(force_min_indices))
            self.print_custom("Cycles number after cutting fake cycles = ",
                              len(force_min_indices_cutted))

            for i in range(0, len(self.columns_headers_list)):
                if self.columns_headers_list[i] != self.time_column:
                    array = np.load(os.path.join(self.npy_folder_path, self.file_name +
                                                 '_' +
                                                 self.columns_headers_list[i]
                                                 + '.npy')).flatten()
                    array_rest = array[peak_force_before_cycles_index:]
                    array_rest_maxima = array_rest[force_max_indices_cutted]
                    array_rest_minima = array_rest[force_min_indices_cutted]
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                         self.columns_headers_list[i] + '_max.npy'), array_rest_maxima)
                    np.save(os.path.join(self.npy_folder_path, self.file_name + '_' +
                                         self.columns_headers_list[i] + '_min.npy'), array_rest_minima)

            self.print_custom('Filtered and creep npy files are generated.')
        except Exception as e:
            self.deal_with_exception(e)

    def cut_indices_of_min_max_range(self, array, max_indices, min_indices,
                                     range_upper_value, range_lower_value):
        cutted_max_indices = []
        cutted_min_indices = []

        for max_index in max_indices:
            if abs(array[max_index]) > abs(range_upper_value):
                cutted_max_indices.append(max_index)
        for min_index in min_indices:
            if abs(array[min_index]) < abs(range_lower_value):
                cutted_min_indices.append(min_index)
        return cutted_max_indices, cutted_min_indices

    def cut_indices_of_defined_range(self, array, max_indices, min_indices, range_):
        cutted_max_indices = []
        cutted_min_indices = []

        for max_index, min_index in zip(max_indices, min_indices):
            if abs(array[max_index] - array[min_index]) > range_:
                cutted_max_indices.append(max_index)
                cutted_min_indices.append(min_index)

        if max_indices.size > min_indices.size:
            cutted_max_indices.append(max_indices[-1])
        elif min_indices.size > max_indices.size:
            cutted_min_indices.append(min_indices[-1])

        return cutted_max_indices, cutted_min_indices
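    # Worked example (illustrative): with range_=50, a max/min pair of
    # (95, 42) gives |95 - 42| = 53 > 50 and is kept as a genuine cycle,
    # while a noise pair of (60, 55) differs by only 5 and is dropped.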

    def get_array_max_and_min_indices(self, input_array):

        # Checking dominant sign
        positive_values_count = np.sum(np.array(input_array) >= 0)
        negative_values_count = input_array.size - positive_values_count

        # Getting max and min indices
        if (positive_values_count > negative_values_count):
            force_max_indices = self.get_max_indices(input_array)
            force_min_indices = self.get_min_indices(input_array)
        else:
            force_max_indices = self.get_min_indices(input_array)
            force_min_indices = self.get_max_indices(input_array)

        return force_max_indices, force_min_indices

    def get_max_indices(self, a):
        # This method doesn't qualify first and last elements as max
        max_indices = []
        i = 1
        while i < a.size - 1:
            previous_element = a[i - 1]

            # Skip repeated elements and record previous element value
            first_repeated_element = True

            while i < a.size - 1 and a[i] == a[i + 1]:
                if first_repeated_element:
                    previous_element = a[i - 1]
                    first_repeated_element = False
                if i < a.size - 2:
                    i += 1
                else:
                    break

            if a[i] > a[i + 1] and a[i] > previous_element:
                max_indices.append(i)
            i += 1
        return np.array(max_indices)

    def get_min_indices(self, a):
        # This method doesn't qualify first and last elements as min
        min_indices = []
        i = 1
        while i < a.size - 1:
            previous_element = a[i - 1]

            # Skip repeated elements and record previous element value
            first_repeated_element = True
            while i < a.size - 1 and a[i] == a[i + 1]:
                if first_repeated_element:
                    previous_element = a[i - 1]
                    first_repeated_element = False
                if i < a.size - 2:
                    i += 1
                else:
                    break

            if a[i] < a[i + 1] and a[i] < previous_element:
                min_indices.append(i)
            i += 1
        return np.array(min_indices)
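    # Example (illustrative): for a = np.array([0, 2, 2, 1, 3, 0]),
    # get_max_indices(a) returns [2, 4] (end of the 2-plateau, then the
    # peak at 3) and get_min_indices(a) returns [3]; endpoints never qualify.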

    def _activate_changed(self):
        if not self.activate:
            self.old_peak_force_before_cycles = self.peak_force_before_cycles
            self.peak_force_before_cycles = 0
        else:
            self.peak_force_before_cycles = self.old_peak_force_before_cycles

    def _window_length_changed(self, new):

        if new <= self.polynomial_order:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be bigger than polynomial order.')
            dialog.open()

        if new % 2 == 0 or new <= 0:
            dialog = MessageDialog(
                title='Attention!',
                message='Window length must be an odd positive integer.')
            dialog.open()

    def _polynomial_order_changed(self, new):
        if new >= self.window_length:
            dialog = MessageDialog(
                title='Attention!',
                message='Polynomial order must be smaller than window length.')
            dialog.open()

    #=========================================================================
    # Plotting
    #=========================================================================

    def _plot_settings_btn_fired(self):
        try:
            self.plot_settings.configure_traits(kind='modal')
        except Exception as e:
            self.deal_with_exception(e)

    def npy_files_exist(self, path):
        if os.path.exists(path):
            return True
        else:
            # TODO fix this
            self.print_custom(
                'Please parse csv file to generate npy files first.')
#             dialog = MessageDialog(
#                 title='Attention!',
#                 message='Please parse csv file to generate npy files first.')
#             dialog.open()
            return False

    def filtered_and_creep_npy_files_exist(self, path):
        if os.path.exists(path):
            return True
        else:
            # TODO fix this
            self.print_custom(
                'Please generate filtered and creep npy files first.')
#             dialog = MessageDialog(
#                 title='Attention!',
#                 message='Please generate filtered and creep npy files first.')
#             dialog.open()
            return False

    data_changed = tr.Event

    def _add_plot_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.add_plot_fired)
        thread.start()

    def add_plot_fired(self):
        try:
            if self.apply_filters:
                if not self.filtered_and_creep_npy_files_exist(os.path.join(
                        self.npy_folder_path, self.file_name + '_' + self.x_axis
                        + '_filtered.npy')):
                    return
                x_axis_name = self.x_axis + '_filtered'
                y_axis_name = self.y_axis + '_filtered'
                self.print_custom('Loading npy files...')
                # when mmap_mode!=None, the array will be loaded as 'numpy.memmap'
                # object which doesn't load the array to memory until it's
                # indexed
                x_axis_array = np.load(os.path.join(self.npy_folder_path,
                                                    self.file_name + '_' + self.x_axis
                                                    + '_filtered.npy'), mmap_mode='r')
                y_axis_array = np.load(os.path.join(self.npy_folder_path,
                                                    self.file_name + '_' + self.y_axis
                                                    + '_filtered.npy'), mmap_mode='r')
            else:
                if not self.npy_files_exist(os.path.join(
                        self.npy_folder_path, self.file_name + '_' + self.x_axis
                        + '.npy')):
                    return

                x_axis_name = self.x_axis
                y_axis_name = self.y_axis
                self.print_custom('Loading npy files...')
                # when mmap_mode!=None, the array will be loaded as 'numpy.memmap'
                # object which doesn't load the array to memory until it's
                # indexed
                x_axis_array = np.load(os.path.join(self.npy_folder_path,
                                                    self.file_name + '_' + self.x_axis
                                                    + '.npy'), mmap_mode='r')
                y_axis_array = np.load(os.path.join(self.npy_folder_path,
                                                    self.file_name + '_' + self.y_axis
                                                    + '.npy'), mmap_mode='r')

            if self.plot_settings_active:
                print(self.plot_settings.first_rows)
                print(self.plot_settings.distance)
                print(self.plot_settings.num_of_rows_after_each_distance)
                print(np.size(x_axis_array))
                indices = self.get_indices_array(np.size(x_axis_array),
                                                 self.plot_settings.first_rows,
                                                 self.plot_settings.distance,
                                                 self.plot_settings.num_of_rows_after_each_distance)
                x_axis_array = self.x_axis_multiplier * x_axis_array[indices]
                y_axis_array = self.y_axis_multiplier * y_axis_array[indices]
            else:
                x_axis_array = self.x_axis_multiplier * x_axis_array
                y_axis_array = self.y_axis_multiplier * y_axis_array

            self.print_custom('Adding Plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.figure.add_subplot(1, 1, 1)

            ax.set_xlabel(x_axis_name)
            ax.set_ylabel(y_axis_name)
            ax.plot(x_axis_array, y_axis_array, 'k',
                    linewidth=1.2, color=np.random.rand(3), label=self.file_name +
                    ', ' + x_axis_name)

            ax.legend()
            self.data_changed = True
            self.print_custom('Finished adding plot.')

        except Exception as e:
            self.deal_with_exception(e)

    def _add_creep_plot_fired(self):
        # Run method on different thread so GUI doesn't freeze
        #thread = Thread(target = threaded_function, function_args = (10,))
        thread = Thread(target=self.add_creep_plot_fired)
        thread.start()

    def add_creep_plot_fired(self):
        try:
            if not self.filtered_and_creep_npy_files_exist(os.path.join(
                    self.npy_folder_path, self.file_name + '_' + self.x_axis
                    + '_max.npy')):
                return

            self.print_custom('Loading npy files...')
            disp_max = self.x_axis_multiplier * \
                np.load(os.path.join(self.npy_folder_path,
                                     self.file_name + '_' + self.x_axis + '_max.npy'))
            disp_min = self.x_axis_multiplier * \
                np.load(os.path.join(self.npy_folder_path,
                                     self.file_name + '_' + self.x_axis + '_min.npy'))
            complete_cycles_number = disp_max.size

            self.print_custom('Adding creep-fatigue plot...')
            mpl.rcParams['agg.path.chunksize'] = 10000

            ax = self.figure.add_subplot(1, 1, 1)

            ax.set_xlabel('Cycles number')
            ax.set_ylabel(self.x_axis)

            if self.plot_every_nth_point > 1:
                disp_max = disp_max[0::self.plot_every_nth_point]
                disp_min = disp_min[0::self.plot_every_nth_point]

            if self.smooth:
                # Keeping the first item of the array and filtering the rest
                disp_max = np.concatenate((
                    np.array([disp_max[0]]),
                    savgol_filter(disp_max[1:],
                                  window_length=self.window_length,
                                  polyorder=self.polynomial_order)
                ))
                disp_min = np.concatenate((
                    np.array([disp_min[0]]),
                    savgol_filter(disp_min[1:],
                                  window_length=self.window_length,
                                  polyorder=self.polynomial_order)
                ))

            if self.normalize_cycles:
                ax.plot(np.linspace(0, 1., disp_max.size), disp_max,
                        'k', linewidth=1.2, color=np.random.rand(3), label='Max'
                        + ', ' + self.file_name + ', ' + self.x_axis)
                ax.plot(np.linspace(0, 1., disp_min.size), disp_min,
                        'k', linewidth=1.2, color=np.random.rand(3), label='Min'
                        + ', ' + self.file_name + ', ' + self.x_axis)
            else:
                ax.plot(np.linspace(0, complete_cycles_number,
                                    disp_max.size), disp_max,
                        'k', linewidth=1.2, color=np.random.rand(3), label='Max'
                        + ', ' + self.file_name + ', ' + self.x_axis)
                ax.plot(np.linspace(0, complete_cycles_number,
                                    disp_min.size), disp_min,
                        'k', linewidth=1.2, color=np.random.rand(3), label='Min'
                        + ', ' + self.file_name + ', ' + self.x_axis)

            ax.legend()
            self.data_changed = True
            self.print_custom('Finished adding creep-fatigue plot.')

        except Exception as e:
            self.deal_with_exception(e)

    def get_indices_array(self,
                          array_size,
                          first_rows,
                          distance,
                          num_of_rows_after_each_distance):
        result_1 = np.arange(first_rows)
        result_2 = np.arange(start=first_rows, stop=array_size,
                             step=distance + num_of_rows_after_each_distance)
        result_2_updated = np.array([], dtype=np.int_)

        for result_2_value in result_2:
            data_slice = np.arange(result_2_value, result_2_value +
                                   num_of_rows_after_each_distance)
            result_2_updated = np.concatenate((result_2_updated, data_slice))

        result = np.concatenate((result_1, result_2_updated))
        return result
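    # Worked example (illustrative): get_indices_array(20, 4, 3, 2) keeps the
    # first 4 rows [0..3], then repeatedly skips `distance`=3 rows and keeps
    # 2: indices [4, 5], [9, 10], [14, 15], [19, 20]. Note the last slice can
    # run past array_size - 1, so callers rely on the data being long enough.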

    def reset(self):
        self.columns_to_be_averaged = []
        self.log = ''

    def print_custom(self, *input_args):
        print(*input_args)
        if self.log == '':
            self.log = ''.join(str(e) for e in list(input_args))
        else:
            self.log = self.log + '\n' + \
                ''.join(str(e) for e in list(input_args))

    def deal_with_exception(self, e):
        self.print_custom('SOMETHING WENT WRONG!')
        self.print_custom('--------- Error message: ---------')
        self.print_custom(traceback.format_exc())
        self.print_custom('----------------------------------')

    def _clear_log_fired(self):
        self.log = ''

    #=========================================================================
    # Configuration of the view
    #=========================================================================
    traits_view = ui.View(
        ui.HSplit(
            ui.VSplit(
                ui.VGroup(
                    ui.VGroup(
                        ui.Item('decimal'),
                        ui.Item('delimiter'),
                        ui.HGroup(
                            ui.UItem('open_file_csv', has_focus=True),
                            ui.UItem('file_csv', style='readonly', width=0.1)),
                        label='Importing csv file',
                        show_border=True)),
                ui.VGroup(
                    ui.VGroup(
                        ui.VGroup(
                            ui.Item('take_time_from_time_column'),
                            ui.Item('time_column',
                                    enabled_when='take_time_from_time_column == True'),
                            ui.Item('records_per_second',
                                    enabled_when='take_time_from_time_column == False'),
                            label='Time calculation',
                            show_border=True),
                        ui.UItem('add_columns_average'),
                        ui.Item('skip_first_rows'),
                        ui.UItem('parse_csv_to_npy', resizable=True),
                        label='Processing csv file',
                        show_border=True)),
                ui.VGroup(
                    ui.VGroup(
                        ui.HGroup(ui.Item('x_axis'), ui.Item(
                            'x_axis_multiplier')),
                        ui.HGroup(ui.Item('y_axis'), ui.Item(
                            'y_axis_multiplier')),
                        ui.VGroup(
                            ui.HGroup(ui.UItem('add_plot'),
                                      ui.Item('apply_filters'),
                                      ui.Item('plot_settings_btn',
                                              label='Settings',
                                              show_label=False,
                                              enabled_when='plot_settings_active == True'),
                                      ui.Item('plot_settings_active',
                                              show_label=False)
                                      ),
                            show_border=True,
                            label='Plotting X axis with Y axis'
                        ),
                        ui.VGroup(
                            ui.HGroup(ui.UItem('add_creep_plot'),
                                      ui.VGroup(
                                          ui.Item('normalize_cycles'),
                                          ui.Item('smooth'),
                                          ui.Item('plot_every_nth_point'))
                                      ),
                            show_border=True,
                            label='Plotting Creep-fatigue of X axis variable'
                        ),
                        ui.UItem('clear_plot', resizable=True),
                        show_border=True,
                        label='Plotting'))
            ),
            ui.VGroup(
                ui.Item('force_column'),
                ui.VGroup(ui.VGroup(
                    ui.Item('window_length'),
                    ui.Item('polynomial_order'),
                    enabled_when='activate == True or smooth == True'),
                    show_border=True,
                    label='Smoothing parameters (Savitzky-Golay filter):'
                ),
                ui.VGroup(ui.VGroup(
                    ui.Item('activate'),
                    ui.Item('peak_force_before_cycles',
                            enabled_when='activate == True')
                ),
                    show_border=True,
                    label='Smooth ascending branch for all displacements:'
                ),
                ui.VGroup(ui.Item('cutting_method'),
                          ui.VGroup(ui.Item('force_max'),
                                    ui.Item('force_min'),
                                    label='Max, Min:',
                                    show_border=True,
                                    enabled_when='cutting_method == "Define Max, Min"'),
                          ui.VGroup(ui.Item('min_cycle_force_range'),
                                    label='Min cycle force range:',
                                    show_border=True,
                                    enabled_when='cutting_method == "Define min cycle range(force difference)"'),
                          show_border=True,
                          label='Cut fake cycles for creep:'),

                ui.VSplit(
                    ui.UItem('generate_filtered_and_creep_npy'),
                    ui.VGroup(
                        ui.Item('log',
                                width=0.1, style='custom'),
                        ui.UItem('clear_log'))),
                show_border=True,
                label='Filters'
            ),
            ui.UItem('figure', editor=MPLFigureEditor(),
                     resizable=True,
                     springy=True,
                     width=0.8,
                     label='2d plots')
        ),
        title='High-cycle fatigue tool',
        resizable=True,
        width=0.85,
        height=0.7
    )
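For reference, a traitsui tool like this is typically launched by instantiating the class and opening its view; a minimal sketch, assuming this example's module-level imports are available:

if __name__ == '__main__':
    hcft = HCFT()
    hcft.configure_traits()  # opens the 'High-cycle fatigue tool' window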
class config(HasTraits):
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc="Workflow Description")
    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    json_sink = Directory(mandatory=False, desc="Location to store json_files")
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")
    save_script_only = traits.Bool(False)
    # Execution

    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether and where the workflow keeps its '
             'intermediary files. True to keep intermediary files.'
    )
    timeout = traits.Float(14.0)
    # Subjects

    #subjects= traits.List(traits.Str, mandatory=True, usedefault=True,
    #    desc="Subject id's. Note: These MUST match the subject id's in the \
    #                            Freesurfer directory. For simplicity, the subject id's should \
    #                            also match with the location of individual functional files.")

    datagrabber = traits.Instance(Data, ())
    # First Level

    interscan_interval = traits.Float()
    film_threshold = traits.Float()
    input_units = traits.Enum('scans', 'secs')
    is_sparse = traits.Bool(False)
    model_hrf = traits.Bool(True)
    stimuli_as_impulses = traits.Bool(True)
    use_temporal_deriv = traits.Bool(True)
    volumes_in_cluster = traits.Int(1)
    ta = traits.Float()
    tr = traits.Float()
    hpcutoff = traits.Float()
    scan_onset = traits.Int(0)
    scale_regressors = traits.Bool(True)
    bases = traits.Dict(
        {'dgamma': {
            'derivs': False
        }}, use_default=True
    )  # name of basis function and options, e.g. {'dgamma': {'derivs': True}}

    # preprocessing info
    preproc_config = traits.File(desc="preproc config file")
    use_compcor = traits.Bool(desc="use noise components from CompCor")
    #advanced_options
    use_advanced_options = Bool(False)
    advanced_options = traits.Code()
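A hypothetical instantiation of this config (all paths and values below are placeholders, not from the source):

c = config()
c.working_dir = '/scratch/working_dir'
c.sink_dir = '/data/results'
c.surf_dir = '/data/freesurfer_subjects'
c.tr = 2.0
c.hpcutoff = 128.0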
Example #16
class Trainer(t.HasStrictTraits):
    model: models.BaseNet = t.Instance(torch.nn.Module, transient=True)

    def _model_default(self):

        # Merge 'base config' (if requested) and any overrides in 'model_config'
        if self.base_config:
            model_config = get_ref_arch(self.base_config)
        else:
            model_config = {}
        if self.model_config:
            model_config.update(self.model_config)
        if self.data_spec:
            model_config.update(
                {
                    "input_channels": self.data_spec["input_channels"],
                    "num_output_classes": [
                        s["num_classes"] for s in self.data_spec["output_spec"]
                    ],
                }
            )
        # create model accordingly
        model_class = getattr(models, self.model_class)
        return model_class(**model_config)

    base_config: str = t.Str()
    model_config: dict = t.Dict()
    model_class: str = t.Enum("FilterNet", "DeepConvLSTM")

    lr_exp: float = t.Float(-3.0)
    batch_size: int = t.Int()
    win_len: int = t.Int(512)
    n_samples_per_batch: int = t.Int(5000)
    train_step: int = t.Int(16)
    seed: int = t.Int()
    decimation: int = t.Int(1)
    optim_type: str = t.Enum(["Adam", "SGD", "RMSprop"])
    loss_func: str = t.Enum(["cross_entropy", "binary_cross_entropy"])
    patience: int = t.Int(10)
    lr_decay: float = t.Float(0.95)
    weight_decay: float = t.Float(1e-4)
    alpha: float = t.Float(0.99)
    momentum: float = t.Float(0.25)
    validation_fold: int = t.Int()
    epoch_size: float = t.Float(2.0)
    y_cols: str = t.Str()
    sensor_subset: str = t.Str()

    has_null_class: bool = t.Bool()

    def _has_null_class_default(self):
        return self.data_spec["output_spec"][0]["classes"][0] in ("", "Null")

    predict_null_class: bool = t.Bool(True)

    _class_weights: torch.Tensor = t.Instance(torch.Tensor)

    def __class_weights_default(self):
        # No class weights for now: they didn't seem to improve things
        #   significantly and added yet another hyper-parameter. Using zero
        #   didn't seem to work well.
        if False and self.has_null_class and not self.predict_null_class:
            cw = torch.ones(self.model.num_output_classes, device=self.device)
            cw[0] = 0.01
            cw /= cw.sum()
            return cw
        return None

    dataset: str = t.Enum(
        ["opportunity", "smartphone_hapt", "har", "intention_recognition"]
    )
    name: str = t.Str()

    def _name_default(self):
        import time

        modelstr = self.model.__class__.__name__
        timestr = time.strftime("%Y%m%d-%H%M%S")
        return f"{modelstr}_{timestr}"

    model_path: str = t.Str()

    def _model_path_default(self):
        return f"saved_models/{self.name}/"

    data_spec: dict = t.Any()
    epoch_iters: int = t.Int(0)
    train_state: TrainState = t.Instance(TrainState, ())
    cp_iter: int = t.Int()

    cuda: bool = t.Bool(transient=True)

    def _cuda_default(self):
        return torch.cuda.is_available()

    device: str = t.Str(transient=True)

    def _device_default(self):
        return "cuda" if self.cuda else "cpu"

    dl_train: DataLoader = t.Instance(DataLoader, transient=True)

    def _dl_train_default(self):
        return self._get_dl("train")

    dl_val: DataLoader = t.Instance(DataLoader, transient=True)

    def _dl_val_default(self):
        return self._get_dl("val")

    dl_test: DataLoader = t.Instance(DataLoader, transient=True)

    def _dl_test_default(self):
        return self._get_dl("test")

    def _get_dl(self, s):

        if self.dataset == "opportunity":
            from filternet.datasets.opportunity import get_x_y_contig
        elif self.dataset == "smartphone_hapt":
            from filternet.datasets.smartphone_hapt import get_x_y_contig
        elif self.dataset == "har":
            from filternet.datasets.har import get_x_y_contig
        elif self.dataset == "intention_recognition":
            from filternet.datasets.intention_recognition import get_x_y_contig
        else:
            raise ValueError(f"Unknown dataset {self.dataset}")

        kwargs = {}
        if self.y_cols:
            kwargs["y_cols"] = self.y_cols
        if self.sensor_subset:
            kwargs["sensor_subset"] = self.sensor_subset

        Xc, ycs, data_spec = get_x_y_contig(s, **kwargs)

        if s == "train":
            # Training shuffles, and we set epoch size to the length of the dataset. We can set train_step as
            # small as we want to get more windows; we'll only run len(Xc)/win_len of them in each training
            # epoch.
            self.epoch_iters = int(len(Xc) / self.decimation)
            X, ys = sliding_window_x_y(
                Xc, ycs, win_len=self.win_len, step=self.train_step, shuffle=False
            )
            # Set the overall data spec using the training set,
            #  and modify later if more info is needed.
            self.data_spec = data_spec
        else:
            # Val and test data are not shuffled.
            # Each point is inferred ~twice b/c step = win_len/2
            X, ys = sliding_window_x_y(
                Xc,
                ycs,
                win_len=self.win_len,
                step=int(self.win_len / 2),
                shuffle=False,  # Cannot be true with windows
            )

        dl = DataLoader(
            TensorDataset(torch.Tensor(X), *[torch.Tensor(y).long() for y in ys]),
            batch_size=self.batch_size,
            shuffle=(s == "train"),
        )
        return dl

    def _batch_size_default(self):
        batch_size = int(self.n_samples_per_batch / self.win_len)
        print(f"Batch size: {batch_size}")
        return batch_size
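    # Illustrative arithmetic: with the defaults n_samples_per_batch=5000 and
    # win_len=512, batch_size = int(5000 / 512) = 9 windows per batch.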

    optimizer = t.Any(transient=True)

    def _optimizer_default(self):
        if self.optim_type == "SGD":
            optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=10 ** (self.lr_exp),
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        elif self.optim_type == "Adam":
            optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=10 ** (self.lr_exp),
                weight_decay=self.weight_decay,
                amsgrad=True,
            )
        elif self.optim_type == "RMSprop":
            optimizer = torch.optim.RMSprop(
                self.model.parameters(),
                lr=10 ** (self.lr_exp),
                alpha=self.alpha,
                weight_decay=self.weight_decay,
                momentum=self.momentum,
            )
        else:
            raise NotImplementedError(self.optim_type)
        return optimizer

    iteration: int = t.Property(t.Int)

    def _get_iteration(self):
        return len(self.train_state.epoch_records) + 1

    lr_scheduler = t.Any(transient=True)

    def _lr_scheduler_default(self):
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, self.lr_decay  # , last_epoch=self._iteration
        )

        # If this is being re-instantiated in mid-training, then we must
        #  iterate scheduler forward to match the training step.
        for i in range(self.iteration):
            if self.lr_decay != 1:
                lr_scheduler.step()

        return lr_scheduler

    #####
    # Training Methods
    ##
    def _train_batch(self, data, targets):
        self.optimizer.zero_grad()
        loss, output, _targets, _ = self._run_model_on_batch(data, targets)
        loss.backward()
        self.optimizer.step()
        # if self.max_lr:
        #     self.lr_scheduler.step()

        return loss, output, _targets

    def _run_model_on_batch(self, data, targets):
        targets = torch.stack(targets)

        if self.cuda:
            data, targets = data.cuda(), targets.cuda()

        output = self.model(data)

        _targets = self.model.transform_targets(targets, one_hot=False)
        if self.loss_func == "cross_entropy":
            _losses = [
                F.cross_entropy(o, t, weight=self._class_weights)
                for o, t in zip(output, _targets)
            ]
            loss = sum(_losses)
        elif self.loss_func == "binary_cross_entropy":
            _targets_onehot = self.model.transform_targets(targets, one_hot=True)
            _losses = [
                F.binary_cross_entropy_with_logits(o, t, weight=self._class_weights)
                for o, t in zip(output, _targets_onehot)
            ]
            loss = sum(_losses)
        else:
            raise NotImplementedError(self.loss_func)

        # Assume only 1 output:

        return loss, output[0], _targets[0], _losses[0]

    def _calc_validation_loss(self):
        running_loss = 0
        self.model.eval()
        with torch.no_grad():
            for batch_idx, (data, *targets) in enumerate(self.dl_val):
                loss, _, _, _ = self._run_model_on_batch(data, targets)
                running_loss += loss.item() * data.size(0)

        return running_loss / len(self.dl_val.dataset)

    def _train_epoch(self):

        self.model.train()

        train_losses = []
        train_accs = []

        for batch_idx, (data, *targets) in enumerate(self.dl_train):
            if (
                batch_idx * data.shape[0] * data.shape[2]
                > self.epoch_iters * self.epoch_size
            ):
                # we've effectively finished one epoch worth of data; break!
                break

            batch_loss, batch_output, batch_targets = self._train_batch(data, targets)
            train_losses.append(batch_loss.detach().cpu().item())
            batch_preds = torch.argmax(batch_output, 1, False)
            train_accs.append(
                (batch_preds == batch_targets).detach().cpu().float().mean().item()
            )

        if self.lr_decay != 1:
            self.lr_scheduler.step()

        return EpochMetrics(loss=np.mean(train_losses), acc=np.mean(train_accs))

    def _val_epoch(self):
        return self._eval_epoch(self.dl_val)

    def _eval_epoch(self, data_loader):
        # Validation
        self.model.eval()

        losses = []
        outputs = []
        targets = []

        with torch.no_grad():
            for batch_idx, (data, *target) in enumerate(data_loader):
                (
                    batch_loss,
                    batch_output,
                    batch_targets,
                    train_losses,
                ) = self._run_model_on_batch(data, target)

                losses.append(batch_loss.detach().cpu().item())
                outputs.append(
                    torch.argmax(batch_output, 1, False)
                    .detach()
                    .cpu()
                    .data.numpy()
                    .flatten()
                )
                targets.append(batch_targets.detach().cpu().data.numpy().flatten())

        targets = np.hstack(targets)
        outputs = np.hstack(outputs)
        acc = sklearn.metrics.accuracy_score(targets, outputs)
        f1 = sklearn.metrics.f1_score(targets, outputs, average="weighted")

        return EpochMetrics(loss=np.mean(losses), acc=acc, f1=f1)

    def init_data(self):
        # Initiate loading of datasets, model
        _, _, _ = self.dl_train, self.dl_val, self.dl_test
        _ = self.model

    def init_train(self):

        # initialization
        if self.seed:
            torch.manual_seed(self.seed)
        if self.cuda:
            if self.seed:
                torch.cuda.manual_seed(self.seed)
        self.model.to(self.device)

    def train_one_epoch(self, verbose=True) -> EpochRecord:
        """ traing a single epoch -- method tailored to the Ray.tune methodology."""
        epoch_record = EpochRecord(epoch=len(self.train_state.epoch_records))
        self.train_state.epoch_records.append(epoch_record)

        with Timer("Train Epoch", log_output=verbose) as t:
            epoch_record.train = self._train_epoch()
        epoch_record.iter_s_cpu = t.interval_cpu
        epoch_record.iter_s_wall = t.interval_wall
        epoch_record.lr = self.optimizer.param_groups[0]["lr"]

        with Timer("Val Epoch", log_output=verbose):
            epoch_record.val = self._val_epoch()

        df = self.train_state.to_df()

        # Early stopping / checkpointing implementation
        df["raw_metric"] = df.val_loss / df.val_f1
        df["ewma_smoothed_loss"] = (
            df["raw_metric"].ewm(ignore_na=False, halflife=3).mean()
        )
        df["instability_penalty"] = (
            df["raw_metric"].rolling(5, min_periods=3).std().fillna(0.75)
        )
        stopping_metric = df["stopping_metric"] = (
            df["ewma_smoothed_loss"] + df["instability_penalty"]
        )
        epoch_record.stopping_metric = df["stopping_metric"].iloc[-1]

        idx_this_iter = stopping_metric.index.max()
        idx_best_yet = stopping_metric.idxmin()
        self.train_state.best_sm = df.loc[idx_best_yet, "stopping_metric"]
        self.train_state.best_loss = df.loc[idx_best_yet, "val_loss"]
        self.train_state.best_f1 = df.loc[idx_best_yet, "val_f1"]

        if idx_best_yet == idx_this_iter:
            # Best yet! Checkpoint.
            epoch_record.should_checkpoint = True
            self.cp_iter = epoch_record.epoch

        else:
            if self.patience is not None:
                patience_counter = idx_this_iter - idx_best_yet
                assert patience_counter >= 0
                if patience_counter > self.patience:
                    if verbose:
                        print(
                            f"Early stop! Out of patience ( {patience_counter} > {self.patience} )"
                        )
                    epoch_record.done = True

        if verbose:
            self.print_train_summary()

        return epoch_record

    def train(self, max_epochs=50, verbose=True):
        """ A pretty standard training loop, constrained to stop in `max_epochs` but may stop early if our
        custom stopping metric does not improve for `self.patience` epochs. Always checkpoints
        when a new best stopping_metric is achieved. An alternative to using
        ray.tune for training."""

        self.init_data()
        self.init_train()

        while True:
            epoch_record = self.train_one_epoch(verbose=verbose)

            if epoch_record.should_checkpoint:
                last_cp = self._save()
                if verbose:
                    print(f"<<<< Checkpointed ({last_cp}) >>>")
            if epoch_record.done:
                break
            if epoch_record.epoch >= max_epochs:
                break

        # Save trainer state, but not the model.
        self._save(save_model=False)
        if verbose:
            print(self.model_path)

    def print_train_summary(self):
        df = self.train_state.to_df()

        with pd.option_context(
            "display.max_rows",
            100,
            "display.max_columns",
            100,
            "display.precision",
            3,
            "display.width",
            180,
        ):
            print(df.drop(["done"], axis=1, errors="ignore"))

    def _save(self, checkpoint_dir=None, save_model=True, save_trainer=True):
        """ Saves/checkpoints model state and training state to disk. """
        if checkpoint_dir is None:
            checkpoint_dir = self.model_path
        else:
            self.model_path = checkpoint_dir

        os.makedirs(checkpoint_dir, exist_ok=True)

        # save model params
        model_path = os.path.join(checkpoint_dir, "model.pth")
        trainer_path = os.path.join(checkpoint_dir, "trainer.pth")

        if save_model:
            torch.save(self.model.state_dict(), model_path)
        if save_trainer:
            with open(trainer_path, "wb") as f:
                pickle.dump(self, f)

        return checkpoint_dir

    def _restore(self, checkpoint_dir=None):
        """ Restores model state and training state from disk. """

        if checkpoint_dir is None:
            checkpoint_dir = self.model_path

        model_path = os.path.join(checkpoint_dir, "model.pth")
        trainer_path = os.path.join(checkpoint_dir, "trainer.pth")

        # Reconstitute old trainer and copy state to this trainer.
        with open(trainer_path, "rb") as f:
            other_trainer = pickle.load(f)

        self.__setstate__(other_trainer.__getstate__())

        # Load model (after loading state in case we need to re-initialize model from config)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

        # Be careful to reinitialize optimizer and lr scheduler
        self.optimizer = self._optimizer_default()
        self.lr_scheduler = self._lr_scheduler_default()
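
A note on the early stopping above: the checkpoint/stop decision in train_one_epoch is driven entirely by a smoothed, stability-penalized ratio of validation loss to validation F1. Here is a minimal, self-contained sketch of just that metric, assuming only pandas and two synthetic validation series standing in for the recorded history:

import pandas as pd

# Synthetic history: loss shrinks and F1 grows, with one noisy epoch.
df = pd.DataFrame({
    "val_loss": [1.00, 0.80, 0.70, 0.65, 0.72, 0.60],
    "val_f1":   [0.50, 0.60, 0.65, 0.70, 0.66, 0.72],
})

# Same recipe as train_one_epoch: penalize high loss and low F1 ...
df["raw_metric"] = df.val_loss / df.val_f1
# ... smooth it with an EWMA (halflife of 3 epochs) ...
df["ewma_smoothed_loss"] = df["raw_metric"].ewm(ignore_na=False, halflife=3).mean()
# ... and add a rolling-std term that penalizes unstable runs.
df["instability_penalty"] = df["raw_metric"].rolling(5, min_periods=3).std().fillna(0.75)
df["stopping_metric"] = df["ewma_smoothed_loss"] + df["instability_penalty"]

# The epoch with the lowest stopping metric is the one that gets checkpointed.
print(df["stopping_metric"].idxmin())

Before min_periods epochs have accumulated, the fixed 0.75 fill-in acts as a conservative instability penalty, so very early epochs rarely win the checkpoint.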
Example #17
class BCSliceI(bu.Model):
    '''
    Implements the IBC functionality for a constrained dof.
    '''
    name = tr.Str('<unnamed>')

    var = tr.Enum('u', 'f')

    slice = tr.Instance(FEGridNodeSlice)
    link_slice = tr.Instance(FEGridNodeSlice)

    bcdof_list = tr.List(BCDof)

    def reset(self):
        self.bcdof_list = []

    link_coeffs = tr.List(tr.Float)
    '''
    Coefficients of the linear combination of linked DOFs.

    The linked DOFs (see the ``link_dofs`` property below) determine the
    value of the current DOF: if the list is empty, the current DOF is
    prescribed directly; otherwise its value is given by the linear
    combination of the DOFs in the list, weighted by these coefficients
    (see the sketch after this example).
    '''

    dims = tr.List(tr.Int)

    _link_dims = tr.List(tr.Int)
    link_dims = tr.Property(tr.List(tr.Int))

    def _get_link_dims(self):
        if len(self._link_dims) == 0:
            return self.dims
        else:
            return self._link_dims

    def _set_link_dims(self, link_dims):
        self._link_dims = link_dims

    value = tr.Float

    time_function = tr.Instance(TimeFunction, ())

    def _time_function_default(self):
        return TFMonotonic()

    space_function = tr.Instance(MFnLineArray, ())

    def _space_function_default(self):
        return MFnLineArray(xdata=[0, 1], ydata=[1, 1], extrapolate='diff')

    def is_essential(self):
        return self.var == 'u'

    def is_linked(self):
        return len(self.link_dofs) > 0

    def is_constrained(self):
        '''
        Return True if a DOF is either explicitly prescribed or depends on other DOFs.
        '''
        return self.is_essential() or self.is_linked()

    def is_natural(self):
        return self.var == 'f'

    def setup(self, sctx):
        '''
        Locate the spatial context.
        '''
        if self.link_slice is None:
            for node_dofs, dof_X in zip(self.slice.dofs, self.slice.dof_X):
                for dof in node_dofs[self.dims]:
                    self.bcdof_list.append(
                        BCDof(
                            var=self.var,
                            dof=dof,
                            value=self.value,
                            # link_coeffs=self.link_coeffs,
                            time_function=self.time_function))
        else:
            # apply the linked slice
            n_link_nodes = len(self.link_slice.dofs.flatten())
            link_dofs = self.link_dofs
            if n_link_nodes == 1:
                #
                link_dof = self.link_slice.dofs.flatten()[0]
                link_coeffs = self.link_coeffs
                for node_dofs, dof_X in zip(self.slice.dofs, self.slice.dof_X):
                    for dof, link_dof, link_coeff in zip(
                            node_dofs[self.dims], link_dofs, link_coeffs):
                        self.bcdof_list.append(
                            BCDof(var=self.var,
                                  dof=dof,
                                  link_dofs=[link_dof],
                                  value=self.value,
                                  link_coeffs=[link_coeff],
                                  time_function=self.time_function))
            else:
                for node_dofs, dof_X, node_link_dofs, link_dof_X in \
                    zip(self.slice.dofs, self.slice.dof_X,
                        self.link_slice.dofs, self.link_slice.dof_X):
                    #print('node', node_dofs, node_link_dofs)
                    #print('node[dims]', node_dofs[self.dims],
                    # node_link_dofs[self.link_dims])
                    for dof, link_dof, link_coeff in zip(
                            node_dofs[self.dims],
                            node_link_dofs[self.link_dims], self.link_coeffs):
                        #print('dof, link, coeff', dof, link_dof, link_coeff)
                        self.bcdof_list.append(
                            BCDof(var=self.var,
                                  dof=dof,
                                  link_dofs=[link_dof],
                                  value=self.value,
                                  link_coeffs=[link_coeff],
                                  time_function=self.time_function))

    def register(self, K):
        '''Register the boundary condition in the equation system.
        '''
        for bcond in self.bcdof_list:
            bcond.register(K)

    def apply_essential(self, K):

        for bcond in self.bcdof_list:
            bcond.apply_essential(K)

    def apply(self, step_flag, sctx, K, R, t_n, t_n1):

        for bcond in self.bcdof_list:
            bcond.apply(step_flag, sctx, K, R, t_n, t_n1)

    #-------------------------------------------------------------------------
    # Constrained DOFs
    #-------------------------------------------------------------------------

    dofs = tr.Property

    def _get_dofs(self):
        return np.unique(self.slice.dofs[..., self.dims].flatten())

    dof_X = tr.Property

    def _get_dof_X(self):
        return self.slice.dof_X

    n_dof_nodes = tr.Property

    def _get_n_dof_nodes(self):
        sliceshape = self.dofs.shape
        return sliceshape[0] * sliceshape[1]

    #-------------------------------------------------------------------------
    # Link DOFs
    #-------------------------------------------------------------------------
    link_dofs = tr.Property(tr.List)

    def _get_link_dofs(self):
        if self.link_slice is not None:
            return np.unique(self.link_slice.dofs[...,
                                                  self.link_dims].flatten())
        else:
            return []

    link_dof_X = tr.Property

    def _get_link_dof_X(self):
        return self.link_slice.dof_X

    n_link_dof_nodes = tr.Property

    def _get_n_link_dof_nodes(self):
        sliceshape = self.link_dofs.shape
        return sliceshape[0] * sliceshape[1]
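
The linked branch of setup above encodes a plain linear constraint: each constrained DOF takes the value of a linear combination of its link DOFs. A minimal numeric sketch under that reading, using NumPy stand-ins rather than the FEGridNodeSlice/BCDof API (treating `value` as an additive offset is an assumption here):

import numpy as np

# Hypothetical DOF vector with the constraint u[3] = 0.5*u[0] + 0.5*u[1] + value
u = np.array([1.0, 3.0, 0.0, 0.0])
link_dofs = [0, 1]        # DOFs that drive the constrained one
link_coeffs = [0.5, 0.5]  # coefficients of the linear combination
value = 0.1               # prescribed offset (assumed additive)

u[3] = sum(c * u[d] for c, d in zip(link_coeffs, link_dofs)) + value
print(u[3])  # 2.1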
Example #18
class TStep(bu.Model):
    '''Manage the data and metadata of a time step within an iteration loop.
    '''
    title = tr.Str('<unnamed>')

    tloop_type = tr.Type(ITLoop)
    '''Type of time loop to be used with the model
    '''

    #=========================================================================
    # HISTORY
    #=========================================================================
    hist_type = tr.Type(Hist)

    hist = tr.Property(tr.Instance(IHist))
    r'''History representation of the model response.
    '''
    @tr.cached_property
    def _get_hist(self):
        return self.hist_type(tstep_source=self)

    debug = tr.Bool(False)

    t_n1 = tr.Float(0.0, auto_set=False, enter_set=True)
    '''Target value of the control variable.
    '''
    U_n = tr.Float(0.0, auto_set=False, enter_set=True)
    '''Current fundamental value of the primary variable.
    '''
    U_k = tr.Float(0.0, auto_set=False, enter_set=True)
    '''Current trial value of the primary variable.
    '''

    def init_state(self):
        '''Initialize state.
        '''
        self.U_n = 0.0
        self.t_n1 = 0.0
        self.U_k = 0.0

    def record_state(self):
        '''Provide the current state for history recording.
        '''
        pass

    _corr_pred = tr.Property(depends_on='U_k,t_n1')

    @tr.cached_property
    def _get__corr_pred(self):
        return self.get_corr_pred(self.U_k, self.t_n1)

    R = tr.Property

    def _get_R(self):
        R, _ = self._corr_pred
        return R

    dR = tr.Property

    def _get_dR(self):
        _, dR = self._corr_pred
        return dR

    R_norm = tr.Property

    def _get_R_norm(self):
        R = self.R
        return np.sqrt(R * R)

    def make_iter(self):
        d_U = self.R / self.dR
        self.U_k += d_U

    def make_incr(self):
        '''Update the control, primary and state variables.
        '''
        self.U_n = self.U_k
        # self.hist.record_timestep()

    sim = tr.Property()
    '''Launch a simulator - currently only one simulator is allowed
    for a model. Multiple might also make sense when different solvers
    are to be compared. The simulator pulls the time loop type
    from the model.
    '''
    @tr.cached_property
    def _get_sim(self):
        return Simulator(self)
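
make_iter above performs a single Newton-style update, d_U = R / dR. To see the loop close, here is a hedged sketch with a hypothetical scalar subclass (QuadraticTStep is not part of the source; the sign of dR is chosen so that the built-in update converges):

import numpy as np

class QuadraticTStep(TStep):
    '''Hypothetical scalar model: residual R(U) = t - U**2.'''
    def get_corr_pred(self, U_k, t_n1):
        # Return (residual, derivative) as consumed by the cached _corr_pred.
        return t_n1 - U_k ** 2, 2.0 * U_k

ts = QuadraticTStep()
ts.t_n1 = 2.0
ts.U_k = 1.0
while ts.R_norm > 1e-10:
    ts.make_iter()  # U_k += R / dR -- the Babylonian square-root iteration here
print(ts.U_k)       # ~1.41421356, i.e. sqrt(2.0)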
Example #19
class Vis2DField(Vis2D):

    model = tr.DelegatesTo('sim')

    x_file = tr.File
    file_list = tr.List(tr.File)

    var = tr.Str('<unnamed>')

    dir = tr.Directory

    def new_dir(self):
        self.dir = tempfile.mkdtemp()

    def setup(self):
        self.new_dir()
        # make a loop over the DomainState
        fe_domain = self.sim.tstep.fe_domain
        domain = fe_domain[2]
        xdomain = domain.xmodel
        r_Eia = np.einsum(
            'Eira,Eia->Eir',
            xdomain.T_Emra[..., :xdomain.x_Eia.shape[-1]], xdomain.x_Eia
        )

        file_name = 'slice_x_%s' % (self.var,)
        target_file = os.path.join(
            self.dir, file_name.replace('.', '_') + '.npy'
        )
        #print('r', r_Eia[..., :-1])
        np.save(target_file, r_Eia[..., :-1])
        self.x_file = target_file

    def get_x_Eir(self):
        return np.load(self.x_file)

    def update(self):
        ts = self.sim.tstep
        fe_domain = self.sim.tstep.fe_domain
        domain = fe_domain[2]
        xdomain = domain.xmodel
        U = ts.U_k
        t = ts.t_n1
        s_Emr = xdomain.map_U_to_field(U)

        var_function = domain.tmodel.var_dict.get(self.var, None)
        if var_function is None:
            raise ValueError('no such variable: %s' % self.var)

        state_k = copy.deepcopy(domain.state_n)
        var_k = var_function(s_Emr, ts.t_n1, **state_k)

        target_file = self.filename(t)

        #np.save(target_file, s_Emr)
        np.save(target_file, var_k)
        self.file_list.append(target_file)

    def filename(self, t):
        file_name = 'slice_%s_step_%08.4f' % (self.var, t)
        target_file = os.path.join(
            self.dir, file_name.replace('.', '_')
        ) + '.npy'
        return target_file
Example #20
File: spm.py Project: servoz/capsul
def edition_widget(engine, environment):
    ''' Edition GUI for SPM config - see
    :class:`~capsul.qt_gui.widgets.settings_editor.SettingsEditor`
    '''
    from soma.qt_gui.controller_widget import ScrollControllerWidget
    from soma.controller import Controller
    import types
    import traits.api as traits

    def validate_config(widget):
        controller = widget.controller_widget.controller
        with widget.engine.settings as session:
            values = {}
            if controller.directory in (None, traits.Undefined, ''):
                values['directory'] = None
            else:
                values['directory'] = controller.directory
            values['standalone'] = controller.standalone
            values['version'] = controller.version
            id = 'spm%s%s' % (controller.version,
                              '-standalone' if controller.standalone else '')
            values['config_id'] = id
            query = 'config_id == "%s"' % id
            conf = session.config('spm', 'global', selection=query)
            if conf is None:
                session.new_config('spm', widget.environment, values)
            else:
                for k in ('directory', 'standalone', 'version'):
                    setattr(conf, k, values[k])

    controller = Controller()
    controller.add_trait(
        "directory",
        traits.Directory(traits.Undefined,
                         output=False,
                         desc="Directory containing SPM."))
    controller.add_trait(
        "standalone",
        traits.Bool(True, desc="If True, use the standalone version of SPM."))
    controller.add_trait(
        'version',
        traits.Str(traits.Undefined,
                   output=False,
                   desc='Version string for SPM: "12", "8", etc.'))

    conf = engine.settings.select_configurations(environment, {'spm': 'any'})
    if conf:
        controller.directory = conf.get('capsul.engine.module.spm',
                                        {}).get('directory', traits.Undefined)
        controller.standalone = conf.get('capsul.engine.module.spm',
                                         {}).get('standalone', True)
        controller.version = conf.get('capsul.engine.module.spm',
                                      {}).get('version', '12')

    # TODO handle several configs

    widget = ScrollControllerWidget(controller, live=True)
    widget.engine = engine
    widget.environment = environment
    widget.accept = types.MethodType(validate_config, widget)

    return widget
Example #21
class DataAxis(t.HasTraits):
    name = t.Str()
    units = t.Str()
    scale = t.Float()
    offset = t.Float()
    size = t.CInt()
    low_value = t.Float()
    high_value = t.Float()
    value = t.Range('low_value', 'high_value')
    low_index = t.Int(0)
    high_index = t.Int()
    slice = t.Instance(slice)
    navigate = t.Bool(t.Undefined)
    index = t.Range('low_index', 'high_index')
    axis = t.Array()
    continuous_value = t.Bool(False)

    def __init__(self,
                 size,
                 index_in_array=None,
                 name=t.Undefined,
                 scale=1.,
                 offset=0.,
                 units=t.Undefined,
                 navigate=t.Undefined):
        super(DataAxis, self).__init__()
        self.events = Events()
        self.events.index_changed = Event("""
            Event that triggers when the index of the `DataAxis` changes

            Triggers after the internal state of the `DataAxis` has been
            updated.

            Arguments
            ---------
            obj : The DataAxis that the event belongs to.
            index : The new index
            """,
                                          arguments=["obj", 'index'])
        self.events.value_changed = Event("""
            Event that triggers when the value of the `DataAxis` changes

            Triggers after the internal state of the `DataAxis` has been
            updated.

            Arguments
            ---------
            obj : The DataAxis that the event belongs to.
            value : The new value
            """,
                                          arguments=["obj", 'value'])
        self._suppress_value_changed_trigger = False
        self._suppress_update_value = False
        self.name = name
        self.units = units
        self.scale = scale
        self.offset = offset
        self.size = size
        self.high_index = self.size - 1
        self.low_index = 0
        self.index = 0
        self.update_axis()
        self.navigate = navigate
        self.axes_manager = None
        self.on_trait_change(self.update_axis, ['scale', 'offset', 'size'])
        self.on_trait_change(self._update_slice, 'navigate')
        self.on_trait_change(self.update_index_bounds, 'size')
        # The slice must be updated even if the default value did not
        # change to correctly set its value.
        self._update_slice(self.navigate)

    def _index_changed(self, name, old, new):
        self.events.index_changed.trigger(obj=self, index=self.index)
        if not self._suppress_update_value:
            new_value = self.axis[self.index]
            if new_value != self.value:
                self.value = new_value

    def _value_changed(self, name, old, new):
        old_index = self.index
        new_index = self.value2index(new)
        if self.continuous_value is False:  # Only values in the grid allowed
            if old_index != new_index:
                self.index = new_index
                if new == self.axis[self.index]:
                    self.events.value_changed.trigger(obj=self, value=new)
            elif old_index == new_index:
                new_value = self.index2value(new_index)
                if new_value == old:
                    self._suppress_value_changed_trigger = True
                    try:
                        self.value = new_value
                    finally:
                        self._suppress_value_changed_trigger = False

                elif new_value == new and not\
                        self._suppress_value_changed_trigger:
                    self.events.value_changed.trigger(obj=self, value=new)
        else:  # Inter-grid values are allowed. This feature is deprecated
            self.events.value_changed.trigger(obj=self, value=new)
            if old_index != new_index:
                self._suppress_update_value = True
                self.index = new_index
                self._suppress_update_value = False

    @property
    def index_in_array(self):
        if self.axes_manager is not None:
            return self.axes_manager._axes.index(self)
        else:
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_array attribute "
                " is not defined")

    @property
    def index_in_axes_manager(self):
        if self.axes_manager is not None:
            return self.axes_manager._get_axes_in_natural_order().\
                index(self)
        else:
            raise AttributeError(
                "This DataAxis does not belong to an AxesManager"
                " and therefore its index_in_array attribute "
                " is not defined")

    def _get_positive_index(self, index):
        if index < 0:
            index = self.size + index
            if index < 0:
                raise IndexError("index out of bounds")
        return index

    def _get_index(self, value):
        if isfloat(value):
            return self.value2index(value)
        else:
            return value

    def _get_array_slices(self, slice_):
        """Returns a slice to slice the corresponding data axis without
        changing the offset and scale of the DataAxis.

        Parameters
        ----------
        slice_ : {float, int, slice}

        Returns
        -------
        my_slice : slice

        """
        v2i = self.value2index

        if isinstance(slice_, slice):
            start = slice_.start
            stop = slice_.stop
            step = slice_.step
        else:
            if isfloat(slice_):
                start = v2i(slice_)
            else:
                start = self._get_positive_index(slice_)
            stop = start + 1
            step = None

        if isfloat(step):
            step = int(round(step / self.scale))
        if isfloat(start):
            try:
                start = v2i(start)
            except ValueError:
                if start > self.high_value:
                    # The start value is above the axis limit
                    raise IndexError(
                        "Start value above axis high bound for  axis %s."
                        "value: %f high_bound: %f" %
                        (repr(self), start, self.high_value))
                else:
                    # The start value is below the axis limit,
                    # we slice from the start.
                    start = None
        if isfloat(stop):
            try:
                stop = v2i(stop)
            except ValueError:
                if stop < self.low_value:
                    # The stop value is below the axis limits
                    raise IndexError(
                        "Stop value below axis low bound for  axis %s."
                        "value: %f low_bound: %f" %
                        (repr(self), stop, self.low_value))
                else:
                    # The stop value is above the axis limit,
                    # we slice until the end.
                    stop = None

        if step == 0:
            raise ValueError("slice step cannot be zero")

        return slice(start, stop, step)

    def _slice_me(self, slice_):
        """Returns a slice to slice the corresponding data axis and
        change the offset and scale of the DataAxis accordingly.

        Parameters
        ----------
        slice_ : {float, int, slice}

        Returns
        -------
        my_slice : slice

        """
        i2v = self.index2value

        my_slice = self._get_array_slices(slice_)

        start, stop, step = my_slice.start, my_slice.stop, my_slice.step

        if start is None:
            if step is None or step > 0:
                start = 0
            else:
                start = self.size - 1
        self.offset = i2v(start)
        if step is not None:
            self.scale *= step

        return my_slice
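
    # A hedged illustration (not from the original source): for an axis with
    # offset 0.0, scale 1.0 and size 10, _slice_me(slice(2.0, 8.0, 2.0))
    # converts the float endpoints via value2index into slice(2, 8, 2) and
    # re-calibrates the axis in place to offset = 2.0, scale = 2.0.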

    def _get_name(self):
        if self.name is t.Undefined:
            if self.axes_manager is None:
                name = "Unnamed"
            else:
                name = "Unnamed " + ordinal(self.index_in_axes_manager)
        else:
            name = self.name
        return name

    def __repr__(self):
        text = '<%s axis, size: %i' % (
            self._get_name(),
            self.size,
        )
        if self.navigate is True:
            text += ", index: %i" % self.index
        text += ">"
        return text

    def __str__(self):
        return self._get_name() + " axis"

    def update_index_bounds(self):
        self.high_index = self.size - 1

    def update_axis(self):
        self.axis = generate_axis(self.offset, self.scale, self.size)
        if len(self.axis) != 0:
            self.low_value, self.high_value = (self.axis.min(),
                                               self.axis.max())

    def _update_slice(self, value):
        if value is False:
            self.slice = slice(None)
        else:
            self.slice = None

    def get_axis_dictionary(self):
        adict = {
            'name': self.name,
            'scale': self.scale,
            'offset': self.offset,
            'size': self.size,
            'units': self.units,
            'navigate': self.navigate
        }
        return adict

    def copy(self):
        return DataAxis(**self.get_axis_dictionary())

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        cp = self.copy()
        return cp

    def value2index(self, value, rounding=round):
        """Return the closest index to the given value if between the limit.

        Parameters
        ----------
        value : number or numpy array

        Returns
        -------
        index : integer or numpy array

        Raises
        ------
        ValueError if any value is out of the axis limits.

        """
        if value is None:
            return None

        if isinstance(value, np.ndarray):
            if rounding is round:
                rounding = np.round
            elif rounding is math.ceil:
                rounding = np.ceil
            elif rounding is math.floor:
                rounding = np.floor

        index = rounding((value - self.offset) / self.scale)

        if isinstance(value, np.ndarray):
            index = index.astype(int)
            if np.all(self.size > index) and np.all(index >= 0):
                return index
            else:
                raise ValueError("A value is out of the axis limits")
        else:
            index = int(index)
            if self.size > index >= 0:
                return index
            else:
                raise ValueError("The value is out of the axis limits")

    def index2value(self, index):
        if isinstance(index, np.ndarray):
            return self.axis[index.ravel()].reshape(index.shape)
        else:
            return self.axis[index]

    def calibrate(self, value_tuple, index_tuple, modify_calibration=True):
        scale = (value_tuple[1] - value_tuple[0]) /\
            (index_tuple[1] - index_tuple[0])
        offset = value_tuple[0] - scale * index_tuple[0]
        if modify_calibration is True:
            self.offset = offset
            self.scale = scale
        else:
            return offset, scale

    def value_range_to_indices(self, v1, v2):
        """Convert the given range to index range.

        When a value is outside the axis limits, the nearest endpoint is used instead.

        Parameters
        ----------
        v1, v2 : float
            The end points of the interval in the axis units. v2 must be
            greater than v1.

        """
        if v1 is not None and v2 is not None and v1 > v2:
            raise ValueError("v2 must be greater than v1.")

        if v1 is not None and self.low_value < v1 <= self.high_value:
            i1 = self.value2index(v1)
        else:
            i1 = 0
        if v2 is not None and self.high_value > v2 >= self.low_value:
            i2 = self.value2index(v2)
        else:
            i2 = self.size - 1
        return i1, i2

    def update_from(self, axis, attributes=["scale", "offset", "units"]):
        """Copy values of specified axes fields from the passed AxesManager.

        Parameters
        ----------
        axis : DataAxis
            The DataAxis instance to use as a source for values.
        attributes : iterable container of strings.
            The names of the attributes to update. If an attribute does not
            exist in either of the DataAxis instances, an AttributeError will
            be raised.
        Returns
        -------
        A boolean indicating whether any changes were made.

        """
        any_changes = False
        changed = {}
        for f in attributes:
            if getattr(self, f) != getattr(axis, f):
                changed[f] = getattr(axis, f)
        if len(changed) > 0:
            self.trait_set(**changed)
            any_changes = True
        return any_changes
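
Underneath the bookkeeping, a DataAxis is an affine map, axis[i] = offset + scale * i, so value2index reduces to round((value - offset) / scale). A quick sketch, assuming the module's own helpers (generate_axis, Events) are importable as in the source:

ax = DataAxis(size=10, scale=0.5, offset=-1.0)
print(ax.axis)                              # [-1.0, -0.5, 0.0, ..., 3.5]
print(ax.value2index(1.0))                  # int(round((1.0 - (-1.0)) / 0.5)) == 4
print(ax.index2value(4))                    # -1.0 + 0.5 * 4 == 1.0
print(ax.value_range_to_indices(0.0, 2.0))  # (2, 6)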
Example #22
class config(HasTraits):
    uuid = traits.Str(desc="UUID")

    # Directories
    working_dir = Directory(mandatory=True, desc="Location of the Nipype working directory")
    base_dir = Directory(os.path.abspath('.'),exists=True, desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(mandatory=True, desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False, desc="Location to store crash files")
    surf_dir = Directory(os.path.abspath('.'),mandatory=True, desc="Freesurfer subjects directory")

    # Execution
    run_using_plugin = Bool(False, usedefault=True, desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS", "MultiProc", "SGE", "Condor",
        usedefault=True,
        desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
        usedefault=True, desc='Plugin arguments.')
    test_mode = Bool(False, mandatory=False, usedefault=True,
        desc='Affects whether and where the workflow keeps its \
                            intermediary files. True to keep intermediary files.')
    # Subjects
    datagrabber = traits.Instance(Data, ())
    subjects = traits.List(traits.Str, mandatory=True, usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files.")
    func_template = traits.String('%s/cleaned_resting.nii.gz')
    reg_template = traits.String('%s/cleaned_resting_reg.dat')
    ref_template = traits.String('%s/cleaned_resting_ref.nii.gz')

    # Target surface
    target_surf = traits.Enum('fsaverage4', 'fsaverage3', 'fsaverage5',
                              'fsaverage6', 'fsaverage', 'subject',
                              desc='which average surface to map to')
    surface_fwhm = traits.List([5], traits.Float(), mandatory=True,
        usedefault=True,
        desc="How much to smooth on target surface")
    projection_stem = traits.Str('-projfrac-avg 0 1 0.1',
                                 desc='how to project data onto the surface')
    combine_surfaces = traits.Bool(desc=('compute correlation matrix across '
                                         'both left and right surfaces'))

    # Saving output
    out_type = traits.Enum('mat', 'hdf5', desc='mat or hdf5')
    hdf5_package = traits.Enum('h5py', 'pytables',
        desc='which hdf5 package to use')
    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()
    save_script_only = traits.Bool(False)

    # Atlas mapping
    surface_atlas = traits.Str('None',
                               desc='Name of parcellation atlas')

    # Buttons
    check_func_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        subs = self.subjects
        for s in subs:
            for template in [self.func_template, self.ref_template,
                             self.reg_template]:
                check_path(os.path.join(self.base_dir, template % s))
            check_path(os.path.join(self.surf_dir, s))
Example #23
class config(HasTraits):
    uuid = traits.Str(desc="UUID")

    # Directories
    base_dir = Directory(
        os.path.abspath('.'),
        exists=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(mandatory=True,
                         desc="Location where the BIP will store the results")
    surf_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Freesurfer subjects directory")

    # Subjects
    datagrabber = traits.Instance(Data, ())
    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files."
    )
    func_template = traits.String('%s/cleaned_resting.nii.gz')
    reg_template = traits.String('%s/cleaned_resting_reg.dat')
    ref_template = traits.String('%s/cleaned_resting_ref.nii.gz')

    # Target surface
    target_surf = traits.Enum('fsaverage4',
                              'fsaverage3',
                              'fsaverage5',
                              'fsaverage6',
                              'fsaverage',
                              'subject',
                              desc='which average surface to map to')
    surface_fwhm = traits.List([5],
                               traits.Float(),
                               mandatory=True,
                               usedefault=True,
                               desc="How much to smooth on target surface")
    projection_stem = traits.Str('-projfrac-avg 0 1 0.1',
                                 desc='how to project data onto the surface')
    combine_surfaces = traits.Bool(desc=('compute correlation matrix across '
                                         'both left and right surfaces'))

    # Saving output
    out_type = traits.Enum('mat', 'hdf5', desc='mat or hdf5')
    hdf5_package = traits.Enum('h5py',
                               'pytables',
                               desc='which hdf5 package to use')
    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()
    save_script_only = traits.Bool(False)

    # Atlas mapping
    surface_atlas = traits.Str('None', desc='Name of parcellation atlas')

    # Buttons
    check_func_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        subs = self.subjects
        for s in subs:
            for template in [
                    self.func_template, self.ref_template, self.reg_template
            ]:
                check_path(os.path.join(self.base_dir, template % s))
            check_path(os.path.join(self.surf_dir, s))
Example #24
def configure_controller(cls):
    c = Controller()
    c.add_trait('param_type', traits.Str('Str'))
    c.add_trait('is_output', traits.Bool(True))
    return c
Example #25
class Parameter(t.HasTraits):
    """Model parameter

    Attributes
    ----------
    value : float or array
        The value of the parameter for the current location. The value
        for other locations is stored in map.
    bmin, bmax: float
        Lower and upper bounds of the parameter value.
    twin : {None, Parameter}
        If it is not None, the value of the current parameter is
        a function of the given Parameter. The function is by default
        the identity function, but it can be defined by twin_function
    twin_function_expr: str
        Expression of the ``twin_function`` that enables setting a functional
        relationship between the parameter and its twin. If ``twin`` is not
        ``None``, the parameter value is calculated as the output of calling the
        twin function with the value of the twin parameter. The string is
        parsed using sympy, so permitted values are any valid sympy expressions
        of one variable. If the function is invertible the twin inverse function
        is set automatically.
    twin_inverse_function_expr : str
        Expression of the ``twin_inverse_function`` that enables setting the
        value of the twin parameter. If ``twin`` is not
        ``None``, its value is set to the output of calling the
        twin inverse function with the value provided. The string is
        parsed using sympy, so permitted values are any valid sympy expressions
        of one variable.
    twin_function : function
        **Setting this attribute manually
        is deprecated in HyperSpy newer than 1.1.2. It will become private in
        HyperSpy 2.0. Please use ``twin_function_expr`` instead.**
    twin_inverse_function : function
        **Setting this attribute manually
        is deprecated in HyperSpy newer than 1.1.2. It will become private in
        HyperSpy 2.0. Please use ``twin_inverse_function_expr`` instead.**
    ext_force_positive : bool
        If True, the parameter value is set to be the absolute value
        of the input value i.e. if we set Parameter.value = -3, the
        value stored is 3 instead. This is useful to bound a value
        to be positive in an optimization without actually using an
        optimizer that supports bounding.
    ext_bounded : bool
        Similar to ext_force_positive, but in this case the bounds are
        defined by bmin and bmax. It is a better idea to use
        an optimizer that supports bounding though.

    Methods
    -------
    as_signal(field = 'values')
        Get a parameter map as a signal object
    plot()
        Plots the value of the Parameter at all locations.
    export(folder=None, name=None, format=None, save_std=False)
        Saves the value of the parameter map to the specified format
    connect, disconnect(function)
        Call the functions connected when the value attribute changes.

    """
    __number_of_elements = 1
    __value = 0
    __free = True
    _bounds = (None, None)
    __twin = None
    _axes_manager = None
    __ext_bounded = False
    __ext_force_positive = False

    # traitsui bugs out trying to make an editor for this, so always specify!
    # (it bugs out, because both editors share the object, and Array editors
    # don't like non-sequence objects). TextEditor() works well, so does
    # RangeEditor() as it works with bmin/bmax.
    value = t.Property(t.Either([t.CFloat(0), Array()]))

    units = t.Str('')
    free = t.Property(t.CBool(True))

    bmin = t.Property(NoneFloat(), label="Lower bounds")
    bmax = t.Property(NoneFloat(), label="Upper bounds")
    _twin_function_expr = ""
    _twin_inverse_function_expr = ""
    twin_function = None
    _twin_inverse_function = None
    _twin_inverse_sympy = None

    def __init__(self):
        self._twins = set()
        self.events = Events()
        self.events.value_changed = Event("""
            Event that triggers when the `Parameter.value` changes.

            The event triggers after the internal state of the `Parameter` has
            been updated.

            Arguments
            ---------
            obj : Parameter
                The `Parameter` that the event belongs to
            value : {float | array}
                The new value of the parameter
            """,
                                          arguments=["obj", 'value'])
        self.std = None
        self.component = None
        self.grad = None
        self.name = ''
        self.units = ''
        self.map = None
        self.model = None
        self._whitelist = {
            '_id_name': None,
            'value': None,
            'std': None,
            'free': None,
            'units': None,
            'map': None,
            '_bounds': None,
            'ext_bounded': None,
            'name': None,
            'ext_force_positive': None,
            'twin_function_expr': None,
            'twin_inverse_function_expr': None,
            'self': ('id', None),
        }
        self._slicing_whitelist = {'map': 'inav'}

    def _load_dictionary(self, dictionary):
        """Load data from dictionary

        Parameters
        ----------
        dict : dictionary
            A dictionary containing at least the following items:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.load_from_dictionary`
            * any field from _whitelist.keys() *
        Returns
        -------
        id_value : int
            the ID value of the original parameter, to be later used for setting
            up the correct twins

        """
        if dictionary['_id_name'] == self._id_name:
            load_from_dictionary(self, dictionary)
            return dictionary['self']
        else:
            raise ValueError(
                "_id_name of parameter and dictionary do not match,\n"
                "parameter._id_name = %s\n"
                "dictionary['_id_name'] = %s" %
                (self._id_name, dictionary['_id_name']))

    def __repr__(self):
        text = ''
        text += 'Parameter %s' % self.name
        if self.component is not None:
            text += ' of %s' % self.component._get_short_description()
        text = '<' + text + '>'
        return text

    def __len__(self):
        return self._number_of_elements

    @property
    def twin_function_expr(self):
        return self._twin_function_expr

    @twin_function_expr.setter
    def twin_function_expr(self, value):
        if not value:
            self.twin_function = None
            self.twin_inverse_function = None
            self._twin_function_expr = ""
            self._twin_inverse_sympy = None
            return
        expr = sympy.sympify(value)
        if len(expr.free_symbols) > 1:
            raise ValueError("The expression must contain only one variable.")
        elif len(expr.free_symbols) == 0:
            raise ValueError("The expression must contain one variable, "
                             "it contains none.")
        x = tuple(expr.free_symbols)[0]
        self.twin_function = lambdify(x, expr.evalf())
        self._twin_function_expr = value
        if not self.twin_inverse_function:
            y = sympy.Symbol(x.name + "2")
            try:
                inv = sympy.solveset(sympy.Eq(y, expr), x)
                self._twin_inverse_sympy = lambdify(y, inv)
                self._twin_inverse_function = None
            except:
                # Not all may have a suitable solution.
                self._twin_inverse_function = None
                self._twin_inverse_sympy = None
                _logger.warning(
                    "The function {} is not invertible. Setting the value of "
                    "{} will raise an AttributeError unless you set manually "
                    "``twin_inverse_function_expr``. Otherwise, set the "
                    "value of its twin parameter instead.".format(value, self))

    @property
    def twin_inverse_function_expr(self):
        if self.twin:
            return self._twin_inverse_function_expr
        else:
            return ""

    @twin_inverse_function_expr.setter
    def twin_inverse_function_expr(self, value):
        if not value:
            self.twin_inverse_function = None
            self._twin_inverse_function_expr = ""
            return
        expr = sympy.sympify(value)
        if len(expr.free_symbols) > 1:
            raise ValueError("The expression must contain only one variable.")
        elif len(expr.free_symbols) == 0:
            raise ValueError("The expression must contain one variable, "
                             "it contains none.")
        x = tuple(expr.free_symbols)[0]
        self._twin_inverse_function = lambdify(x, expr.evalf())
        self._twin_inverse_function_expr = value

    @property
    def twin_inverse_function(self):
        if (not self.twin_inverse_function_expr and self.twin_function_expr
                and self._twin_inverse_sympy):
            return lambda x: self._twin_inverse_sympy(x).pop()
        else:
            return self._twin_inverse_function

    @twin_inverse_function.setter
    def twin_inverse_function(self, value):
        self._twin_inverse_function = value

    def _get_value(self):
        if self.twin is None:
            return self.__value
        else:
            if self.twin_function:
                return self.twin_function(self.twin.value)
            else:
                return self.twin.value

    def _set_value(self, value):
        try:
            # Use try/except instead of hasattr("__len__") because a numpy
            # memmap has a __len__ wrapper even for numbers that raises a
            # TypeError when calling. See issue #349.
            if len(value) != self._number_of_elements:
                raise ValueError("The length of the parameter must be ",
                                 self._number_of_elements)
            else:
                if not isinstance(value, tuple):
                    value = tuple(value)
        except TypeError:
            if self._number_of_elements != 1:
                raise ValueError("The length of the parameter must be ",
                                 self._number_of_elements)
        old_value = self.__value

        if self.twin is not None:
            if self.twin_function is not None:
                if self.twin_inverse_function is not None:
                    self.twin.value = self.twin_inverse_function(value)
                    return
                else:
                    raise AttributeError(
                        "This parameter has a ``twin_function`` but"
                        "its ``twin_inverse_function`` is not defined.")
            else:
                self.twin.value = value
                return

        if self.ext_bounded is False:
            self.__value = value
        else:
            if self.ext_force_positive is True:
                value = np.abs(value)
            if self._number_of_elements == 1:
                if self.bmin is not None and value <= self.bmin:
                    self.__value = self.bmin
                elif self.bmax is not None and value >= self.bmax:
                    self.__value = self.bmax
                else:
                    self.__value = value
            else:
                bmin = (self.bmin if self.bmin is not None else -np.inf)
                bmax = (self.bmax if self.bmax is not None else np.inf)
                self.__value = np.clip(value, bmin, bmax)

        if (self._number_of_elements != 1
                and not isinstance(self.__value, tuple)):
            self.__value = tuple(self.__value)
        if old_value != self.__value:
            self.events.value_changed.trigger(value=self.__value, obj=self)
        self.trait_property_changed('value', old_value, self.__value)
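
    # A hedged illustration (not from the original source) of the external
    # bounding described in the class docstring:
    #
    #     p.ext_bounded = True
    #     p.bmin, p.bmax = 0.0, 1.0
    #     p.value = 5.0    # clipped to bmax, stored as 1.0
    #     p.ext_force_positive = True
    #     p.value = -0.5   # abs() is applied first, stored as 0.5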

    # Fix the parameter when coupled
    def _get_free(self):
        if self.twin is None:
            return self.__free
        else:
            return False

    def _set_free(self, arg):
        old_value = self.__free
        self.__free = arg
        if self.component is not None:
            self.component._update_free_parameters()
        self.trait_property_changed('free', old_value, self.__free)

    def _on_twin_update(self, value, twin=None):
        if (twin is not None and hasattr(twin, 'events')
                and hasattr(twin.events, 'value_changed')):
            with twin.events.value_changed.suppress_callback(
                    self._on_twin_update):
                self.events.value_changed.trigger(value=value, obj=self)
        else:
            self.events.value_changed.trigger(value=value, obj=self)

    def _set_twin(self, arg):
        if arg is None:
            if self.twin is not None:
                # Store the value of the twin in order to set the
                # value of the parameter when it is uncoupled
                twin_value = self.value
                if self in self.twin._twins:
                    self.twin._twins.remove(self)
                    self.twin.events.value_changed.disconnect(
                        self._on_twin_update)

                self.__twin = arg
                self.value = twin_value
        else:
            if self not in arg._twins:
                arg._twins.add(self)
                arg.events.value_changed.connect(self._on_twin_update,
                                                 ["value"])
            self.__twin = arg

        if self.component is not None:
            self.component._update_free_parameters()

    def _get_twin(self):
        return self.__twin

    twin = property(_get_twin, _set_twin)

    def _get_bmin(self):
        if self._number_of_elements == 1:
            return self._bounds[0]
        else:
            return self._bounds[0][0]

    def _set_bmin(self, arg):
        old_value = self.bmin
        if self._number_of_elements == 1:
            self._bounds = (arg, self.bmax)
        else:
            self._bounds = ((arg, self.bmax), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmin', old_value, arg)

    def _get_bmax(self):
        if self._number_of_elements == 1:
            return self._bounds[1]
        else:
            return self._bounds[0][1]

    def _set_bmax(self, arg):
        old_value = self.bmax
        if self._number_of_elements == 1:
            self._bounds = (self.bmin, arg)
        else:
            self._bounds = ((self.bmin, arg), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmax', old_value, arg)

    @property
    def _number_of_elements(self):
        return self.__number_of_elements

    @_number_of_elements.setter
    def _number_of_elements(self, arg):
        # Do nothing if the number of arguments stays the same
        if self.__number_of_elements == arg:
            return
        if arg <= 1:
            raise ValueError("Please provide an integer greater than 1")
        self._bounds = ((self.bmin, self.bmax), ) * arg
        self.__number_of_elements = arg

        if arg == 1:
            self._Parameter__value = 0
        else:
            self._Parameter__value = (0, ) * arg
        if self.component is not None:
            self.component.update_number_parameters()

    @property
    def ext_bounded(self):
        return self.__ext_bounded

    @ext_bounded.setter
    def ext_bounded(self, arg):
        if arg is not self.__ext_bounded:
            self.__ext_bounded = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    @property
    def ext_force_positive(self):
        return self.__ext_force_positive

    @ext_force_positive.setter
    def ext_force_positive(self, arg):
        if arg is not self.__ext_force_positive:
            self.__ext_force_positive = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    def store_current_value_in_array(self):
        """Store the value and std attributes.

        See also
        --------
        fetch, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        self.map['values'][indices] = self.value
        self.map['is_set'][indices] = True
        if self.std is not None:
            self.map['std'][indices] = self.std

    def fetch(self):
        """Fetch the stored value and std attributes.


        See Also
        --------
        store_current_value_in_array, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        if self.map['is_set'][indices]:
            self.value = self.map['values'][indices]
            self.std = self.map['std'][indices]

    def assign_current_value_to_all(self, mask=None):
        """Assign the current value attribute to all the  indices

        Parameters
        ----------
        mask: {None, boolean numpy array}
            Set only the indices that are not masked i.e. where
            mask is False.

        See Also
        --------
        store_current_value_in_array, fetch

        """
        if mask is None:
            mask = np.zeros(self.map.shape, dtype='bool')
        self.map['values'][mask == False] = self.value
        self.map['is_set'][mask == False] = True

    def _create_array(self):
        """Create the map array to store the information in
        multidimensional datasets.

        """
        shape = self._axes_manager._navigation_shape_in_array
        if not shape:
            shape = [
                1,
            ]
        dtype_ = np.dtype([('values', 'float', self._number_of_elements),
                           ('std', 'float', self._number_of_elements),
                           ('is_set', 'bool', 1)])
        if (self.map is None or self.map.shape != shape
                or self.map.dtype != dtype_):
            self.map = np.zeros(shape, dtype_)
            self.map['std'].fill(np.nan)
            # TODO: in the future this class should have access to
            # axes manager and should be able to fetch its own
            # values. Until then, the next line is necessary to avoid
            # errors when self.std is defined and the shape is different
            # from the newly defined arrays
            self.std = None

    def as_signal(self, field='values'):
        """Get a parameter map as a signal object.

        Please note that this method only works when the navigation
        dimension is greater than 0.

        Parameters
        ----------
        field : {'values', 'std', 'is_set'}

        Raises
        ------
        NavigationDimensionError : if the navigation dimension is 0

        """
        from hyperspy.signal import BaseSignal

        s = BaseSignal(data=self.map[field],
                       axes=self._axes_manager._get_navigation_axes_dicts())
        if self.component is not None and \
                self.component.active_is_multidimensional:
            s.data[np.logical_not(self.component._active_array)] = np.nan

        s.metadata.General.title = ("%s parameter" %
                                    self.name if self.component is None else
                                    "%s parameter of %s component" %
                                    (self.name, self.component.name))
        for axis in s.axes_manager._axes:
            axis.navigate = False
        if self._number_of_elements > 1:
            s.axes_manager._append_axis(size=self._number_of_elements,
                                        name=self.name,
                                        navigate=True)
        s._assign_subclass()
        if field == "values":
            # Add the variance if available
            std = self.as_signal(field="std")
            if not np.isnan(std.data).all():
                std.data = std.data**2
                std.metadata.General.title = "Variance"
                s.metadata.set_item("Signal.Noise_properties.variance", std)
        return s

    def plot(self, **kwargs):
        """Plot parameter signal.

        Parameters
        ----------
        **kwargs
            Any extra keyword arguments are passed to the signal plot.

        Example
        -------
        >>> parameter.plot() #doctest: +SKIP

        Set the minimum and maximum displayed values

        >>> parameter.plot(vmin=0, vmax=1) #doctest: +SKIP
        """
        self.as_signal().plot(**kwargs)

    def export(self, folder=None, name=None, format="hspy", save_std=False):
        """Save the data to a file.

        All the arguments are optional.

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved.
             If `None` the current folder is used by default.
        name : str or None
            The name of the file. If `None` the Components name followed
             by the Parameter `name` attributes will be used by default.
              If a file with the same name exists the name will be
              modified by appending a number to the file path.
        save_std : bool
            If True, also the standard deviation will be saved
        format: str
            The extension of any file format supported by HyperSpy, default hspy

        """
        if format is None:
            format = "hspy"
        if name is None:
            name = self.component.name + '_' + self.name
        filename = incremental_filename(slugify(name) + '.' + format)
        if folder is not None:
            filename = os.path.join(folder, filename)
        self.as_signal().save(filename)
        if save_std is True:
            self.as_signal(field='std').save(append2pathname(filename, '_std'))

    def as_dictionary(self, fullcopy=True):
        """Returns parameter as a dictionary, saving all attributes from
        self._whitelist.keys() For more information see
        :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`

        Parameters
        ----------
        fullcopy : Bool (optional, False)
            Copies of objects are stored, not references. If any found,
            functions will be pickled and signals converted to dictionaries
        Returns
        -------
        dic : dictionary with the following keys:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _twins : list
                a list of ids of the twins of the parameter
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
            * any field from _whitelist.keys() *

        """
        dic = {'_twins': [id(t) for t in self._twins]}
        export_to_dictionary(self, self._whitelist, dic, fullcopy)
        return dic

    def default_traits_view(self):
        # As mentioned above, the default editor for
        # value = t.Property(t.Either([t.CFloat(0), Array()]))
        # gives a ValueError. We therefore implement default_traits_view so
        # that configure/edit_traits will still work straight out of the box.
        # A whitelist controls which traits to include in this view.
        from traitsui.api import RangeEditor, View, Item
        whitelist = ['bmax', 'bmin', 'free', 'name', 'std', 'units', 'value']
        editable_traits = [
            trait for trait in self.editable_traits() if trait in whitelist
        ]
        if 'value' in editable_traits:
            i = editable_traits.index('value')
            v = editable_traits.pop(i)
            editable_traits.insert(
                i,
                Item(v, editor=RangeEditor(low_name='bmin', high_name='bmax')))
        view = View(editable_traits, buttons=['OK', 'Cancel'])
        return view
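Taken together, `as_signal`, `plot` and `export` make a fitted parameter map easy to inspect and archive. A minimal usage sketch, assuming HyperSpy is installed; the random signal `s`, model `m` and Gaussian component `g` are illustrative stand-ins, not part of the code above:

import numpy as np
import hyperspy.api as hs

# Build a toy dataset with 10 spectra so the navigation dimension is > 0,
# which as_signal() requires.
s = hs.signals.Signal1D(np.random.random((10, 100)))
m = s.create_model()
g = hs.model.components1D.Gaussian()
m.append(g)
m.multifit()                       # fit every index; fills each parameter's map

centre_map = g.centre.as_signal()  # navigation-shaped signal of fitted centres
centre_map.plot()                  # equivalent to g.centre.plot()
g.centre.export(format="hspy")     # writes e.g. "Gaussian_centre.hspy"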
Example #26
class config(HasTraits):
    uuid = traits.Str(desc="UUID")

    # Directories
    working_dir = Directory(mandatory=True,
                            desc="Location of the Nipype working directory")
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(mandatory=True,
                         desc="Location where the BIP will store the results")
    crash_dir = Directory(mandatory=False,
                          desc="Location to store crash files")
    save_script_only = traits.Bool(False)
    # Execution
    run_using_plugin = Bool(
        False,
        usedefault=True,
        desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS",
                         "MultiProc",
                         "SGE",
                         "Condor",
                         usedefault=True,
                         desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
                              usedefault=True,
                              desc='Plugin arguments.')
    test_mode = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='Affects whether and where the workflow keeps its \
                            intermediary files. True to keep intermediary files.'
    )
    timeout = traits.Float(14.0)
    # Subjects
    #subjects = traits.List(traits.Str, mandatory=True, usedefault=True,
    #    desc="Subject id's. Note: These MUST match the subject id's in the \
    #                            Freesurfer directory. For simplicity, the subject id's should \
    #                            also match with the location of individual functional files.")
    #fwhm=traits.List(traits.Float())
    #copes_template = traits.String('%s/preproc/output/fwhm_%s/cope*.nii.gz')
    #varcopes_template = traits.String('%s/preproc/output/fwhm_%s/varcope*.nii.gz')
    #contrasts = traits.List(traits.Str,desc="contrasts")

    datagrabber = traits.Instance(Data, ())

    # Regression
    design_csv = traits.File(desc="design .csv file")
    reg_contrasts = traits.Code(
        desc=
        "function named reg_contrasts which takes in 0 args and returns contrasts"
    )
    run_mode = traits.Enum("flame1", "ols", "flame12")
    #Normalization
    norm_template = traits.File(desc='Template of files')
    use_mask = traits.Bool(False)
    mask_file = traits.File()
    #Correction:
    run_correction = traits.Bool(False)
    p_threshold = traits.Float(0.05)
    z_threshold = traits.Float(2.3)
    connectivity = traits.Int(26)
    do_randomize = traits.Bool(False)
    num_iterations = traits.Int(5000)
    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()

    # Buttons
    check_func_datagrabber = Button("Check")
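A minimal sketch of filling in the config above before handing it to a workflow; the directory and plugin values are illustrative assumptions, and the `Data` class used by `datagrabber` is assumed to be importable from the surrounding module:

c = config()
c.working_dir = '/scratch/nipype_work'   # illustrative paths
c.sink_dir = '/data/results'
c.run_using_plugin = True
c.plugin = 'MultiProc'                   # must be one of the Enum values
c.plugin_args = {'n_procs': 4}
print(c.plugin, c.plugin_args)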
Example #27
class General(t.HasTraits):
    title = t.Str(t.Undefined)
    original_filename = t.File(t.Undefined)
    signal_kind = t.Str(t.Undefined)
    record_by = t.Enum('spectrum', 'image', default=t.Undefined)
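The `t.Undefined` defaults above let downstream code distinguish "never set" from any real value. A small sketch of that pattern, assuming `traits.api` is imported as `t` as in the class definition and that the Traits version in use returns the unvalidated default:

g = General()
if g.title is t.Undefined:   # the trait was never assigned
    g.title = 'untitled'
g.record_by = 'spectrum'     # the Enum only accepts the listed values
print(g.title, g.record_by)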
Example #28
class config(HasTraits):
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    working_dir = Directory(mandatory=True, desc="Location of the Nipype working directory")
    base_dir = Directory(os.path.abspath('.'),mandatory=True, desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),mandatory=True, desc="Location where the BIP will store the results")
    field_dir = Directory(desc="Base directory of field-map data (Should be subject-independent) \
                                                 Set this value to None if you don't want fieldmap distortion correction")
    crash_dir = Directory(mandatory=False, desc="Location to store crash files")
    surf_dir = Directory(mandatory=True, desc= "Freesurfer subjects directory")

    # Execution

    run_using_plugin = Bool(False, usedefault=True, desc="True to run pipeline with plugin, False to run serially")
    plugin = traits.Enum("PBS", "PBSGraph","MultiProc", "SGE", "Condor",
        usedefault=True,
        desc="plugin to use, if run_using_plugin=True")
    plugin_args = traits.Dict({"qsub_args": "-q many"},
        usedefault=True, desc='Plugin arguments.')
    test_mode = Bool(False, mandatory=False, usedefault=True,
        desc='Affects whether and where the workflow keeps its \
                            intermediary files. True to keep intermediary files.')
    # Subjects

    subjects= traits.List(traits.Str, mandatory=True, usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files.")
    dwi_template = traits.String('%s/functional.nii.gz')
    bval_template = traits.String('%s/bval')
    bvec_template = traits.String('%s/fbvec')
    run_datagrabber_without_submitting = traits.Bool(desc="Run the datagrabber without \
    submitting to the cluster")
    timepoints_to_remove = traits.Int(0,usedefault=True)

    # Fieldmap

    use_fieldmap = Bool(False, mandatory=False, usedefault=True,
        desc='True to include fieldmap distortion correction. Note: field_dir \
                                     must be specified')
    magnitude_template = traits.String('%s/magnitude.nii.gz')
    phase_template = traits.String('%s/phase.nii.gz')
    TE_diff = traits.Float(desc='difference in B0 field map TEs')
    sigma = traits.Int(2, desc='2D spatial gaussian smoothing stdev (default = 2mm)')
    echospacing = traits.Float(desc="EPI echo spacing")

    # Bvecs
    do_rotate_bvecs = traits.Bool(True, usedefault=True)

    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()

    # Buttons
    check_func_datagrabber = Button("Check")
    check_field_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        subs = self.subjects

        for s in subs:
            if not os.path.exists(os.path.join(self.base_dir, self.dwi_template % s)):
                print("ERROR:", os.path.join(self.base_dir, self.dwi_template % s), "does NOT exist!")
                break
            else:
                print(os.path.join(self.base_dir, self.dwi_template % s), "exists!")

    def _check_field_datagrabber_fired(self):
        subs = self.subjects

        for s in subs:
            if not os.path.exists(os.path.join(self.field_dir, self.magnitude_template % s)):
                print("ERROR:", os.path.join(self.field_dir, self.magnitude_template % s), "does NOT exist!")
                break
            else:
                # Report the same field_dir path that was just checked.
                print(os.path.join(self.field_dir, self.magnitude_template % s), "exists!")
            if not os.path.exists(os.path.join(self.field_dir, self.phase_template % s)):
                print("ERROR:", os.path.join(self.field_dir, self.phase_template % s), "does NOT exist!")
                break
            else:
                print(os.path.join(self.field_dir, self.phase_template % s), "exists!")
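The two button handlers above rely on plain `%` templates: each template holds one `%s` that is substituted with a subject id before the existence check. A self-contained sketch of that mechanic with illustrative paths and subject ids:

import os

base_dir = '/data/study'               # illustrative
dwi_template = '%s/functional.nii.gz'

for subject in ('sub01', 'sub02'):     # illustrative subject ids
    path = os.path.join(base_dir, dwi_template % subject)
    print(path, 'exists!' if os.path.exists(path) else 'does NOT exist!')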
Example #29
class LineInSignal1D(t.HasTraits):
    """Adds a vertical draggable line to a spectrum that reports its
    position to the position attribute of the class.

    Attributes
    ----------
    position : float
        The position of the vertical line in the one dimensional signal. Moving
        the line changes the position but the reverse is not true.
    on : bool
        Turns the line on and off.
    color_str : str
        The color of the line as a matplotlib color string. Changing it
        automatically redraws the line.

    """
    position = t.Float()
    is_ok = t.Bool(False)
    on = t.Bool(False)
    # The following is disabled because as of traits 4.6 the Color trait
    # imports traitsui (!)
    # try:
    #     color = t.Color("black")
    # except ModuleNotFoundError:  # traitsui is not installed
    #     pass
    color_str = t.Str("black")

    def __init__(self, signal):
        if signal.axes_manager.signal_dimension != 1:
            raise SignalDimensionError(signal.axes_manager.signal_dimension, 1)

        self.signal = signal
        self.signal.plot()
        axis_dict = signal.axes_manager.signal_axes[0].get_axis_dictionary()
        am = AxesManager([
            axis_dict,
        ])
        am._axes[0].navigate = True
        # Set the position of the line in the middle of the spectral
        # range by default
        am._axes[0].index = int(round(am._axes[0].size / 2))
        self.axes_manager = am
        self.axes_manager.events.indices_changed.connect(
            self.update_position, [])
        self.on_trait_change(self.switch_on_off, 'on')

    def draw(self):
        self.signal._plot.signal_plot.figure.canvas.draw_idle()

    def switch_on_off(self, obj, trait_name, old, new):
        if not self.signal._plot.is_active:
            return

        if new is True and old is False:
            self._line = VerticalLineWidget(self.axes_manager)
            self._line.set_mpl_ax(self.signal._plot.signal_plot.ax)
            self._line.patch.set_linewidth(2)
            self._color_changed("black", "black")
            # There is not need to call draw because setting the
            # color calls it.

        elif new is False and old is True:
            self._line.close()
            self._line = None
            self.draw()

    def update_position(self, *args, **kwargs):
        if not self.signal._plot.is_active:
            return
        self.position = self.axes_manager.coordinates[0]

    def _color_str_changed(self, old, new):
        if self.on is False:
            return

        # The wx-based ``color`` trait is disabled above (it would import
        # traitsui), so the line color is driven by the ``color_str`` trait,
        # which matplotlib accepts directly.
        self._line.patch.set_color(new)
        self.draw()
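A minimal usage sketch for the class above, assuming HyperSpy is installed and `LineInSignal1D` is importable from its defining module; the random spectrum is an illustrative stand-in for real data:

import numpy as np
import hyperspy.api as hs

s = hs.signals.Signal1D(np.random.random(1024))
line = LineInSignal1D(s)   # plots the signal and builds a one-axis AxesManager
line.on = True             # adds the draggable vertical line to the plot
print(line.position)       # tracks the line as it is dragged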
Example #30
class config(BaseWorkflowConfig):
    uuid = traits.Str(desc="UUID")
    desc = traits.Str(desc='Workflow description')
    # Directories
    base_dir = Directory(
        os.path.abspath('.'),
        mandatory=True,
        desc='Base directory of data. (Should be subject-independent)')
    sink_dir = Directory(os.path.abspath('.'),
                         mandatory=True,
                         desc="Location where the BIP will store the results")
    field_dir = Directory(
        desc="Base directory of field-map data (Should be subject-independent) \
                                                 Set this value to None if you don't want fieldmap distortion correction"
    )
    surf_dir = Directory(mandatory=True, desc="Freesurfer subjects directory")

    # Subjects

    subjects = traits.List(
        traits.Str,
        mandatory=True,
        usedefault=True,
        desc="Subject id's. Note: These MUST match the subject id's in the \
                                Freesurfer directory. For simplicity, the subject id's should \
                                also match with the location of individual functional files."
    )
    dwi_template = traits.String('%s/functional.nii.gz')
    bval_template = traits.String('%s/bval')
    bvec_template = traits.String('%s/fbvec')
    run_datagrabber_without_submitting = traits.Bool(
        desc="Run the datagrabber without \
    submitting to the cluster")
    timepoints_to_remove = traits.Int(0, usedefault=True)

    # Fieldmap

    use_fieldmap = Bool(
        False,
        mandatory=False,
        usedefault=True,
        desc='True to include fieldmap distortion correction. Note: field_dir \
                                     must be specified')
    magnitude_template = traits.String('%s/magnitude.nii.gz')
    phase_template = traits.String('%s/phase.nii.gz')
    TE_diff = traits.Float(desc='difference in B0 field map TEs')
    sigma = traits.Int(
        2, desc='2D spatial gaussian smoothing stdev (default = 2mm)')
    echospacing = traits.Float(desc="EPI echo spacing")

    # Bvecs
    do_rotate_bvecs = traits.Bool(True, usedefault=True)

    # Advanced Options
    use_advanced_options = traits.Bool()
    advanced_script = traits.Code()

    # Buttons
    check_func_datagrabber = Button("Check")
    check_field_datagrabber = Button("Check")

    def _check_func_datagrabber_fired(self):
        subs = self.subjects

        for s in subs:
            if not os.path.exists(
                    os.path.join(self.base_dir, self.dwi_template % s)):
                print("ERROR:",
                      os.path.join(self.base_dir, self.dwi_template % s),
                      "does NOT exist!")
                break
            else:
                print(os.path.join(self.base_dir, self.dwi_template % s),
                      "exists!")

    def _check_field_datagrabber_fired(self):
        subs = self.subjects

        for s in subs:
            if not os.path.exists(
                    os.path.join(self.field_dir, self.magnitude_template % s)):
                print("ERROR:",
                      os.path.join(self.field_dir,
                                   self.magnitude_template % s),
                      "does NOT exist!")
                break
            else:
                # Report the same field_dir path that was just checked.
                print(os.path.join(self.field_dir,
                                   self.magnitude_template % s), "exists!")
            if not os.path.exists(
                    os.path.join(self.field_dir, self.phase_template % s)):
                print("ERROR:",
                      os.path.join(self.field_dir, self.phase_template % s),
                      "does NOT exist!")
                break
            else:
                print(os.path.join(self.field_dir, self.phase_template % s),
                      "exists!")