Example #1
0
class DataModel(HasTraits):
    """This is the data to be plotted in the demo.

    Holds the x values, the channel numbers and a 2D data array whose rows
    correspond to the channels (``y_index``) and whose columns correspond
    to the x values (``x_index``). The shape relationship below is a
    documented requirement; this class does not enforce it.
    """

    # The x values of the data (1D numpy array).
    x_index = Array()

    # The channel numbers (1D numpy array).
    y_index = Array()

    # The data.  The shape of this 2D array must be (y_index.size, x_index.size),
    # i.e. one row per channel and one column per x value.
    data = Array()
Example #2
0
class RepeatedNodesRandomSelector(AbstractRandomSelector):
    """Random node selector backed by an array of repeated node ids.

    While ``_initialized_preferential_attachment`` is set, every node added
    with ``add_node`` and every endpoint of an edge added with ``add_edge``
    is appended to ``repeated_nodes``. Drawing uniformly from that array
    therefore favours nodes that occur more often in it.
    """

    repeated_nodes = Array(dtype=np.int32, shape=(None, ))

    def extract_preferential_attachment(self):
        """Return one entry drawn uniformly at random from ``repeated_nodes``."""
        return random.choice(self.repeated_nodes)

    def add_edge(self, source, target):
        """Append both endpoints of a new edge to ``repeated_nodes``."""
        if self._initialized_preferential_attachment:
            # NOTE: np.append copies the whole array on every call, so adding
            # many edges one by one is quadratic overall.
            self.repeated_nodes = np.append(self.repeated_nodes,
                                            [source, target])

    def remove_edge(self, source, target):
        """Remove one occurrence of each endpoint of an edge.

        A self-loop (``source == target``) removes two occurrences of the
        node. The resulting array is sorted. The array is only replaced once
        both endpoints have been found, so on error it is left unchanged.

        Raises
        ------
        ValueError
            If ``source`` or ``target`` does not occur in ``repeated_nodes``
            (for a self-loop: if the node occurs fewer than two times).
        """
        if not self._initialized_preferential_attachment:
            return
        nodes = np.sort(self.repeated_nodes)
        for node in (source, target):
            # Leftmost position where ``node`` would be inserted; it is a
            # real occurrence only if the element there equals ``node``.
            index = nodes.searchsorted(node)
            if index >= nodes.size or nodes[index] != node:
                raise ValueError(
                    "node %r is not present in repeated_nodes" % (node, ))
            nodes = np.delete(nodes, index)
        self.repeated_nodes = nodes

    def remove_node(self, node):
        """Clear ``repeated_nodes`` and mark it as uninitialized.

        ``node`` itself is not used here. Presumably the base class rebuilds
        the array once the flag is reset -- NOTE(review): confirm.
        """
        self._initialized_preferential_attachment = False
        self.repeated_nodes = np.zeros(0, dtype=np.int32)

    def add_node(self, node):
        """Append ``node`` once to ``repeated_nodes``."""
        if self._initialized_preferential_attachment:
            self.repeated_nodes = np.append(self.repeated_nodes, node)
class DataView(HasTraits):
    """Read-only style table view of an object array, titled 'Annotations'."""

    data = Array(dtype=object)

    def traits_view(self):
        """Build the traitsui ``View`` showing ``data`` as a table.

        The number of columns is taken from the first row of ``data``, so
        ``data`` must contain at least one row. Widths are capped at
        ``MAX_WIDTH``.
        """
        n_cols = len(self.data[0])
        table_width = min(WIDTH_CELL * n_cols, MAX_WIDTH)
        view_width = min(table_width + W_MARGIN, MAX_WIDTH)

        adapter = Array2DAdapter(ncolumns=n_cols,
                                 format='%s',
                                 show_index=True)
        table_item = Item('data',
                          editor=TabularEditor(adapter=adapter),
                          show_label=False,
                          width=table_width,
                          padding=10)
        layout = Group(table_item)
        return View(layout,
                    title='Annotations',
                    width=view_width,
                    height=800,
                    resizable=True,
                    buttons=OKCancelButtons)
Example #4
0
class Parameter(t.HasTraits):
    """Model parameter

    Attributes
    ----------
    value : float or array
        The value of the parameter for the current location. The value
        for other locations is stored in map.
    bmin, bmax: float
        Lower and upper bounds of the parameter value.
    twin : {None, Parameter}
        If it is not None, the value of the current parameter is
        a function of the given Parameter. The function is by default
        the identity function, but it can be defined by twin_function
    twin_function_expr: str
        Expression of the ``twin_function`` that enables setting a functional
        relationship between the parameter and its twin. If ``twin`` is not
        ``None``, the parameter value is calculated as the output of calling the
        twin function with the value of the twin parameter. The string is
        parsed using sympy, so permitted values are any valid sympy expressions
        of one variable. If the function is invertible the twin inverse function
        is set automatically.
    twin_inverse_function_expr : str
        Expression of the ``twin_inverse_function`` that enables setting the
        value of the twin parameter. If ``twin`` is not
        ``None``, its value is set to the output of calling the
        twin inverse function with the value provided. The string is
        parsed using sympy, so permitted values are any valid sympy expressions
        of one variable.
    twin_function : function
        **Setting this attribute manually
        is deprecated in HyperSpy newer than 1.1.2. It will become private in
        HyperSpy 2.0. Please use ``twin_function_expr`` instead.**
    twin_inverse_function : function
        **Setting this attribute manually
        is deprecated in HyperSpy newer than 1.1.2. It will become private in
        HyperSpy 2.0. Please use ``twin_inverse_function_expr`` instead.**
    ext_force_positive : bool
        If True, the parameter value is set to be the absolute value
        of the input value i.e. if we set Parameter.value = -3, the
        value stored is 3 instead. This is useful to bound a value
        to be positive in an optimization without actually using an
        optimizer that supports bounding.
    ext_bounded : bool
        Similar to ext_force_positive, but in this case the bounds are
        defined by bmin and bmax. It is a better idea to use
        an optimizer that supports bounding though.

    Methods
    -------
    as_signal(field = 'values')
        Get a parameter map as a signal object
    plot()
        Plots the value of the Parameter at all locations.
    export(folder=None, name=None, format=None, save_std=False)
        Saves the value of the parameter map to the specified format
    connect, disconnect(function)
        Call the functions connected when the value attribute changes.

    """
    __number_of_elements = 1
    __value = 0
    __free = True
    _bounds = (None, None)
    __twin = None
    _axes_manager = None
    __ext_bounded = False
    __ext_force_positive = False

    # traitsui bugs out trying to make an editor for this, so always specify!
    # (it bugs out, because both editor shares the object, and Array editors
    # don't like non-sequence objects). TextEditor() works well, so does
    # RangeEditor() as it works with bmin/bmax.
    value = t.Property(t.Either([t.CFloat(0), Array()]))

    units = t.Str('')
    free = t.Property(t.CBool(True))

    bmin = t.Property(NoneFloat(), label="Lower bounds")
    bmax = t.Property(NoneFloat(), label="Upper bounds")
    _twin_function_expr = ""
    _twin_inverse_function_expr = ""
    twin_function = None
    _twin_inverse_function = None
    _twin_inverse_sympy = None

    def __init__(self):
        self._twins = set()
        self.events = Events()
        self.events.value_changed = Event("""
            Event that triggers when the `Parameter.value` changes.

            The event triggers after the internal state of the `Parameter` has
            been updated.

            Arguments
            ---------
            obj : Parameter
                The `Parameter` that the event belongs to
            value : {float | array}
                The new value of the parameter
            """,
                                          arguments=["obj", 'value'])
        self.std = None
        self.component = None
        self.grad = None
        self.name = ''
        self.units = ''
        self.map = None
        self.model = None
        # Attributes exported/imported by as_dictionary/_load_dictionary.
        self._whitelist = {
            '_id_name': None,
            'value': None,
            'std': None,
            'free': None,
            'units': None,
            'map': None,
            '_bounds': None,
            'ext_bounded': None,
            'name': None,
            'ext_force_positive': None,
            'twin_function_expr': None,
            'twin_inverse_function_expr': None,
            'self': ('id', None),
        }
        self._slicing_whitelist = {'map': 'inav'}

    def _load_dictionary(self, dictionary):
        """Load data from dictionary

        Parameters
        ----------
        dict : dictionary
            A dictionary containing at least the following items:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.load_from_dictionary`
            * any field from _whitelist.keys() *
        Returns
        -------
        id_value : int
            the ID value of the original parameter, to be later used for setting
            up the correct twins

        Raises
        ------
        ValueError
            If ``dictionary['_id_name']`` differs from ``self._id_name``.

        """
        if dictionary['_id_name'] == self._id_name:
            load_from_dictionary(self, dictionary)
            return dictionary['self']
        else:
            raise ValueError(
                "_id_name of parameter and dictionary do not match, \n"
                "parameter._id_name = %s\n"
                "dictionary['_id_name'] = %s" %
                (self._id_name, dictionary['_id_name']))

    def __repr__(self):
        text = ''
        text += 'Parameter %s' % self.name
        if self.component is not None:
            text += ' of %s' % self.component._get_short_description()
        text = '<' + text + '>'
        return text

    def __len__(self):
        return self._number_of_elements

    @property
    def twin_function_expr(self):
        return self._twin_function_expr

    @twin_function_expr.setter
    def twin_function_expr(self, value):
        if not value:
            self.twin_function = None
            self.twin_inverse_function = None
            self._twin_function_expr = ""
            self._twin_inverse_sympy = None
            return
        expr = sympy.sympify(value)
        if len(expr.free_symbols) > 1:
            raise ValueError("The expression must contain only one variable.")
        elif len(expr.free_symbols) == 0:
            raise ValueError("The expression must contain one variable, "
                             "it contains none.")
        x = tuple(expr.free_symbols)[0]
        self.twin_function = lambdify(x, expr.evalf())
        self._twin_function_expr = value
        # Derive the inverse automatically unless an explicit inverse was
        # given (via ``twin_inverse_function_expr`` or the deprecated
        # ``twin_inverse_function`` setter). The public
        # ``twin_inverse_function`` property must not be used for this
        # check: it also returns the inverse auto-derived from a *previous*
        # expression, which would then be kept instead of recomputed.
        if self._twin_inverse_function is None:
            y = sympy.Symbol(x.name + "2")
            try:
                inv = sympy.solveset(sympy.Eq(y, expr), x)
                self._twin_inverse_sympy = lambdify(y, inv)
            except Exception:
                # Not all may have a suitable solution.
                self._twin_inverse_sympy = None
                _logger.warning(
                    "The function {} is not invertible. Setting the value of "
                    "{} will raise an AttributeError unless you set manually "
                    "``twin_inverse_function_expr``. Otherwise, set the "
                    "value of its twin parameter instead.".format(value, self))

    @property
    def twin_inverse_function_expr(self):
        if self.twin:
            return self._twin_inverse_function_expr
        else:
            return ""

    @twin_inverse_function_expr.setter
    def twin_inverse_function_expr(self, value):
        if not value:
            self.twin_inverse_function = None
            self._twin_inverse_function_expr = ""
            return
        expr = sympy.sympify(value)
        if len(expr.free_symbols) > 1:
            raise ValueError("The expression must contain only one variable.")
        elif len(expr.free_symbols) == 0:
            raise ValueError("The expression must contain one variable, "
                             "it contains none.")
        x = tuple(expr.free_symbols)[0]
        self._twin_inverse_function = lambdify(x, expr.evalf())
        self._twin_inverse_function_expr = value

    @property
    def twin_inverse_function(self):
        # Prefer the inverse auto-derived from ``twin_function_expr`` when no
        # explicit inverse expression is in effect.
        if (not self.twin_inverse_function_expr and self.twin_function_expr
                and self._twin_inverse_sympy):
            return lambda x: self._twin_inverse_sympy(x).pop()
        else:
            return self._twin_inverse_function

    @twin_inverse_function.setter
    def twin_inverse_function(self, value):
        self._twin_inverse_function = value

    def _get_value(self):
        """Getter of the ``value`` trait; follows the twin when coupled."""
        if self.twin is None:
            return self.__value
        else:
            if self.twin_function:
                return self.twin_function(self.twin.value)
            else:
                return self.twin.value

    def _set_value(self, value):
        """Setter of the ``value`` trait.

        When the parameter is twinned, the value is forwarded to the twin
        (through ``twin_inverse_function`` if a ``twin_function`` is set)
        and nothing is stored locally. Otherwise the value is stored,
        clipped to ``bmin``/``bmax`` when ``ext_bounded`` is True, and the
        ``value_changed`` event and trait notification are emitted.

        Raises
        ------
        ValueError
            If the length of ``value`` does not match the number of elements.
        AttributeError
            If a ``twin_function`` is set but its inverse is not defined.
        """
        n_elements = self._number_of_elements
        try:
            # Use try/except instead of hasattr("__len__") because a numpy
            # memmap has a __len__ wrapper even for numbers that raises a
            # TypeError when calling. See issue #349.
            if len(value) != n_elements:
                raise ValueError(
                    "The length of the parameter must be %d" % n_elements)
            else:
                if not isinstance(value, tuple):
                    value = tuple(value)
        except TypeError:
            if n_elements != 1:
                raise ValueError(
                    "The length of the parameter must be %d" % n_elements)
        old_value = self.__value

        if self.twin is not None:
            if self.twin_function is not None:
                if self.twin_inverse_function is not None:
                    self.twin.value = self.twin_inverse_function(value)
                    return
                else:
                    raise AttributeError(
                        "This parameter has a ``twin_function`` but "
                        "its ``twin_inverse_function`` is not defined.")
            else:
                self.twin.value = value
                return

        if self.ext_bounded is False:
            self.__value = value
        else:
            # NOTE(review): ext_force_positive is only honoured when
            # ext_bounded is True, although the class docstring describes
            # it independently -- confirm intended behaviour.
            if self.ext_force_positive is True:
                value = np.abs(value)
            if n_elements == 1:
                if self.bmin is not None and value <= self.bmin:
                    self.__value = self.bmin
                elif self.bmax is not None and value >= self.bmax:
                    self.__value = self.bmax
                else:
                    self.__value = value
            else:
                bmin = (self.bmin if self.bmin is not None else -np.inf)
                bmax = (self.bmax if self.bmax is not None else np.inf)
                self.__value = np.clip(value, bmin, bmax)

        if (n_elements != 1
                and not isinstance(self.__value, tuple)):
            self.__value = tuple(self.__value)
        if old_value != self.__value:
            self.events.value_changed.trigger(value=self.__value, obj=self)
        self.trait_property_changed('value', old_value, self.__value)

    # Fix the parameter when coupled
    def _get_free(self):
        if self.twin is None:
            return self.__free
        else:
            return False

    def _set_free(self, arg):
        old_value = self.__free
        self.__free = arg
        if self.component is not None:
            self.component._update_free_parameters()
        self.trait_property_changed('free', old_value, self.__free)

    def _on_twin_update(self, value, twin=None):
        """Re-emit the twin's ``value_changed`` event on this parameter."""
        if (twin is not None and hasattr(twin, 'events')
                and hasattr(twin.events, 'value_changed')):
            with twin.events.value_changed.suppress_callback(
                    self._on_twin_update):
                self.events.value_changed.trigger(value=value, obj=self)
        else:
            self.events.value_changed.trigger(value=value, obj=self)

    def _detach_from_twin(self):
        """Unregister from the current twin's ``_twins`` and events."""
        if self in self.twin._twins:
            self.twin._twins.remove(self)
            self.twin.events.value_changed.disconnect(self._on_twin_update)

    def _set_twin(self, arg):
        """Couple this parameter to ``arg``, or uncouple it if ``arg`` is None."""
        if arg is None:
            if self.twin is not None:
                # Store the value of the twin in order to set the
                # value of the parameter when it is uncoupled
                twin_value = self.value
                self._detach_from_twin()
                self.__twin = arg
                self.value = twin_value
        else:
            if self.twin is not None and self.twin is not arg:
                # Re-twinning: drop the link to the previous twin so it no
                # longer forwards its value_changed events to this parameter.
                self._detach_from_twin()
            if self not in arg._twins:
                arg._twins.add(self)
                arg.events.value_changed.connect(self._on_twin_update,
                                                 ["value"])
            self.__twin = arg

        if self.component is not None:
            self.component._update_free_parameters()

    def _get_twin(self):
        return self.__twin

    twin = property(_get_twin, _set_twin)

    def _get_bmin(self):
        if self._number_of_elements == 1:
            return self._bounds[0]
        else:
            return self._bounds[0][0]

    def _set_bmin(self, arg):
        old_value = self.bmin
        if self._number_of_elements == 1:
            self._bounds = (arg, self.bmax)
        else:
            self._bounds = ((arg, self.bmax), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmin', old_value, arg)

    def _get_bmax(self):
        if self._number_of_elements == 1:
            return self._bounds[1]
        else:
            return self._bounds[0][1]

    def _set_bmax(self, arg):
        old_value = self.bmax
        if self._number_of_elements == 1:
            self._bounds = (self.bmin, arg)
        else:
            self._bounds = ((self.bmin, arg), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmax', old_value, arg)

    @property
    def _number_of_elements(self):
        return self.__number_of_elements

    @_number_of_elements.setter
    def _number_of_elements(self, arg):
        """Change the number of elements, resetting the value to zero(s).

        Raises
        ------
        ValueError
            If ``arg`` is smaller than 1.
        """
        # Do nothing if the number of arguments stays the same
        if self.__number_of_elements == arg:
            return
        if arg < 1:
            raise ValueError("Please provide an integer number equal "
                             "or greater to 1")
        # Read the bounds before changing the element count: the bmin/bmax
        # getters index ``_bounds`` according to the current count.
        bounds = (self.bmin, self.bmax)
        self.__number_of_elements = arg

        if arg == 1:
            # Scalar parameters store flat (bmin, bmax) bounds.
            self._bounds = bounds
            self.__value = 0
        else:
            self._bounds = (bounds, ) * arg
            self.__value = (0, ) * arg
        if self.component is not None:
            self.component.update_number_parameters()

    @property
    def ext_bounded(self):
        return self.__ext_bounded

    @ext_bounded.setter
    def ext_bounded(self, arg):
        if arg is not self.__ext_bounded:
            self.__ext_bounded = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    @property
    def ext_force_positive(self):
        return self.__ext_force_positive

    @ext_force_positive.setter
    def ext_force_positive(self, arg):
        if arg is not self.__ext_force_positive:
            self.__ext_force_positive = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    def store_current_value_in_array(self):
        """Store the value and std attributes.

        See also
        --------
        fetch, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        self.map['values'][indices] = self.value
        self.map['is_set'][indices] = True
        if self.std is not None:
            self.map['std'][indices] = self.std

    def fetch(self):
        """Fetch the stored value and std attributes.


        See Also
        --------
        store_current_value_in_array, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        if self.map['is_set'][indices]:
            value = self.map['values'][indices]
            std = self.map['std'][indices]
            # Lazy (dask) maps must be computed before use.
            if isinstance(value, dArray):
                value = value.compute()
            if isinstance(std, dArray):
                std = std.compute()
            self.value = value
            self.std = std

    def assign_current_value_to_all(self, mask=None):
        """Assign the current value attribute to all the  indices

        Parameters
        ----------
        mask: {None, boolean numpy array}
            Set only the indices that are not masked i.e. where
            mask is False.

        See Also
        --------
        store_current_value_in_array, fetch

        """
        if mask is None:
            mask = np.zeros(self.map.shape, dtype='bool')
        unmasked = np.logical_not(mask)
        self.map['values'][unmasked] = self.value
        self.map['is_set'][unmasked] = True

    def _create_array(self):
        """Create the map array to store the information in
        multidimensional datasets.

        An existing map is kept when its shape and dtype already match.

        """
        shape = tuple(self._axes_manager._navigation_shape_in_array)
        if not shape:
            # Single spectrum. Use a tuple so the comparison with
            # ``self.map.shape`` below can succeed; with a list it never
            # does and the stored values would be discarded on every call.
            shape = (1, )
        # Shape-1 fields in dtypes won't be collapsed to scalars in a future
        # numpy version (see release notes numpy 1.17.0)
        if self._number_of_elements > 1:
            dtype_ = np.dtype([('values', 'float', self._number_of_elements),
                               ('std', 'float', self._number_of_elements),
                               ('is_set', 'bool')])
        else:
            dtype_ = np.dtype([('values', 'float'), ('std', 'float'),
                               ('is_set', 'bool')])
        if (self.map is None or self.map.shape != shape
                or self.map.dtype != dtype_):
            self.map = np.zeros(shape, dtype_)
            self.map['std'].fill(np.nan)
            # TODO: in the future this class should have access to
            # axes manager and should be able to fetch its own
            # values. Until then, the next line is necessary to avoid
            # errors when self.std is defined and the shape is different
            # from the newly defined arrays
            self.std = None

    def as_signal(self, field='values'):
        """Get a parameter map as a signal object.

        Please note that this method only works when the navigation
        dimension is greater than 0.

        Parameters
        ----------
        field : {'values', 'std', 'is_set'}

        Raises
        ------

        NavigationDimensionError : if the navigation dimension is 0

        """
        from hyperspy.signal import BaseSignal

        s = BaseSignal(data=self.map[field],
                       axes=self._axes_manager._get_navigation_axes_dicts())
        if self.component is not None and \
                self.component.active_is_multidimensional:
            s.data[np.logical_not(self.component._active_array)] = np.nan

        s.metadata.General.title = ("%s parameter" %
                                    self.name if self.component is None else
                                    "%s parameter of %s component" %
                                    (self.name, self.component.name))
        for axis in s.axes_manager._axes:
            axis.navigate = False
        if self._number_of_elements > 1:
            s.axes_manager._append_axis(size=self._number_of_elements,
                                        name=self.name,
                                        navigate=True)
        s._assign_subclass()
        if field == "values":
            # Add the variance if available
            std = self.as_signal(field="std")
            if not np.isnan(std.data).all():
                std.data = std.data**2
                std.metadata.General.title = "Variance"
                s.metadata.set_item("Signal.Noise_properties.variance", std)
        return s

    def plot(self, **kwargs):
        """Plot parameter signal.

        Parameters
        ----------
        **kwargs
            Any extra keyword arguments are passed to the signal plot.

        Example
        -------
        >>> parameter.plot() #doctest: +SKIP

        Set the minimum and maximum displayed values

        >>> parameter.plot(vmin=0, vmax=1) #doctest: +SKIP
        """
        self.as_signal().plot(**kwargs)

    def export(self, folder=None, name=None, format="hspy", save_std=False):
        """Save the data to a file.

        All the arguments are optional.

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved.
             If `None` the current folder is used by default.
        name : str or None
            The name of the file. If `None` the Components name followed
             by the Parameter `name` attributes will be used by default
             (only the Parameter `name` if it has no component).
              If a file with the same name exists the name will be
              modified by appending a number to the file path.
        save_std : bool
            If True, also the standard deviation will be saved
        format: str
            The extension of any file format supported by HyperSpy, default hspy

        """
        if format is None:
            format = "hspy"
        if name is None:
            if self.component is None:
                name = self.name
            else:
                name = self.component.name + '_' + self.name
        filename = incremental_filename(slugify(name) + '.' + format)
        if folder is not None:
            filename = os.path.join(folder, filename)
        self.as_signal().save(filename)
        if save_std is True:
            self.as_signal(field='std').save(append2pathname(filename, '_std'))

    def as_dictionary(self, fullcopy=True):
        """Returns parameter as a dictionary, saving all attributes from
        self._whitelist.keys() For more information see
        :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`

        Parameters
        ----------
        fullcopy : Bool (optional, True)
            Copies of objects are stored, not references. If any found,
            functions will be pickled and signals converted to dictionaries
        Returns
        -------
        dic : dictionary with the following keys:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _twins : list
                a list of ids of the twins of the parameter
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
            * any field from _whitelist.keys() *

        """
        dic = {'_twins': [id(t) for t in self._twins]}
        export_to_dictionary(self, self._whitelist, dic, fullcopy)
        return dic

    def default_traits_view(self):
        # As mentioned above, the default editor for
        # value = t.Property(t.Either([t.CFloat(0), Array()]))
        # gives a ValueError. We therefore implement default_traits_view so
        # that configure/edit_traits will still work straight out of the box.
        # A whitelist controls which traits to include in this view.
        from traitsui.api import RangeEditor, View, Item
        whitelist = ['bmax', 'bmin', 'free', 'name', 'std', 'units', 'value']
        editable_traits = [
            trait for trait in self.editable_traits() if trait in whitelist
        ]
        if 'value' in editable_traits:
            i = editable_traits.index('value')
            v = editable_traits.pop(i)
            editable_traits.insert(
                i,
                Item(v, editor=RangeEditor(low_name='bmin', high_name='bmax')))
        view = View(editable_traits, buttons=['OK', 'Cancel'])
        return view
Example #5
0
class Parameter(t.HasTraits):
    """Model parameter

    Attributes
    ----------
    value : float or array
        The value of the parameter for the current location. The value
        for other locations is stored in map.
    bmin, bmax: float
        Lower and upper bounds of the parameter value.
    twin : {None, Parameter}
        If it is not None, the value of the current parameter is
        a function of the given Parameter. The function is by default
        the identity function, but it can be defined by twin_function
    twin_function : function
        Function that, if self.twin is not None, takes self.twin.value
        as its only argument and returns a float or array that is
        returned when getting Parameter.value
    twin_inverse_function : function
        The inverse of twin_function. If it is None then it is not
        possible to set the value of the parameter twin by setting
        the value of the current parameter.
    ext_force_positive : bool
        If True, the parameter value is set to be the absolute value
        of the input value i.e. if we set Parameter.value = -3, the
        value stored is 3 instead. This is useful to bound a value
        to be positive in an optimization without actually using an
        optimizer that supports bounding. It only takes effect while
        ext_bounded is enabled.
    ext_bounded : bool
        Similar to ext_force_positive, but in this case the bounds are
        defined by bmin and bmax. It is a better idea to use
        an optimizer that supports bounding though.

    Methods
    -------
    as_signal(field = 'values')
        Get a parameter map as a signal object
    plot()
        Plots the value of the Parameter at all locations.
    export(folder=None, name=None, format=None, save_std=False)
        Saves the value of the parameter map to the specified format
    connect, disconnect(function)
        Connect/disconnect functions that are called (without arguments)
        when the value attribute changes.

    """
    __number_of_elements = 1
    __value = 0
    __free = True
    _bounds = (None, None)
    __twin = None
    _axes_manager = None
    __ext_bounded = False
    __ext_force_positive = False

    # traitsui bugs out trying to make an editor for this, so always specify!
    # (it bugs out, because both editor shares the object, and Array editors
    # don't like non-sequence objects). TextEditor() works well, so does
    # RangeEditor() as it works with bmin/bmax.
    value = t.Property(t.Either([t.CFloat(0), Array()]))

    units = t.Str('')
    free = t.Property(t.CBool(True))

    bmin = t.Property(NoneFloat(), label="Lower bounds")
    bmax = t.Property(NoneFloat(), label="Upper bounds")

    def __init__(self):
        self._twins = set()
        self.connected_functions = list()
        self.twin_function = lambda x: x
        self.twin_inverse_function = lambda x: x
        self.std = None
        self.component = None
        self.grad = None
        self.name = ''
        self.map = None
        self.model = None

    def __repr__(self):
        text = 'Parameter %s' % self.name
        if self.component is not None:
            text += ' of %s' % self.component._get_short_description()
        # __repr__ must return str; returning bytes makes repr() raise
        # TypeError on Python 3.
        return '<' + text + '>'

    def __len__(self):
        return self._number_of_elements

    def connect(self, f):
        """Call ``f()`` (no arguments) whenever the value changes.

        When the parameter is twinned, ``f`` is also connected to the twin.
        """
        if f not in self.connected_functions:
            self.connected_functions.append(f)
            if self.twin is not None:
                self.twin.connect(f)

    def disconnect(self, f):
        """Stop calling ``f`` when the value changes (also on the twin)."""
        if f in self.connected_functions:
            self.connected_functions.remove(f)
            if self.twin is not None:
                self.twin.disconnect(f)

    def _get_value(self):
        if self.twin is None:
            return self.__value
        else:
            return self.twin_function(self.twin.value)

    def _set_value(self, arg):
        """Set the value, forwarding to the twin or applying external bounds.

        Raises
        ------
        ValueError
            If the length of ``arg`` does not match the number of elements.
        """
        try:
            # Use try/except instead of hasattr("__len__") because a numpy
            # memmap has a __len__ wrapper even for numbers that raises a
            # TypeError when calling. See issue #349.
            if len(arg) != self._number_of_elements:
                raise ValueError("The length of the parameter must be %s"
                                 % self._number_of_elements)
            elif not isinstance(arg, tuple):
                arg = tuple(arg)
        except TypeError:
            if self._number_of_elements != 1:
                raise ValueError("The length of the parameter must be %s"
                                 % self._number_of_elements)
        old_value = self.__value

        if self.twin is not None:
            if self.twin_inverse_function is not None:
                self.twin.value = self.twin_inverse_function(arg)
            return

        if self.ext_bounded is False:
            self.__value = arg
        else:
            if self.ext_force_positive is True:
                arg = np.abs(arg)
            if self._number_of_elements == 1:
                if self.bmin is not None and arg <= self.bmin:
                    self.__value = self.bmin
                elif self.bmax is not None and arg >= self.bmax:
                    self.__value = self.bmax
                else:
                    self.__value = arg
            else:
                bmin = self.bmin if self.bmin is not None else -np.inf
                # Test bmax itself; checking bmin here ignored a set bmax
                # whenever bmin was None.
                bmax = self.bmax if self.bmax is not None else np.inf
                self.__value = np.clip(arg, bmin, bmax)

        if (self._number_of_elements != 1
                and not isinstance(self.__value, tuple)):
            self.__value = tuple(self.__value)
        if old_value != self.__value:
            # Iterate over a copy: failing callbacks are disconnected, which
            # mutates the list and would otherwise skip the next callback.
            for f in list(self.connected_functions):
                try:
                    f()
                except Exception:
                    self.disconnect(f)
        self.trait_property_changed('value', old_value, self.__value)

    # Fix the parameter when coupled
    def _get_free(self):
        if self.twin is None:
            return self.__free
        else:
            return False

    def _set_free(self, arg):
        old_value = self.__free
        self.__free = arg
        if self.component is not None:
            self.component._update_free_parameters()
        self.trait_property_changed('free', old_value, self.__free)

    def _detach_from_twin(self, twin):
        """Remove self from ``twin._twins`` and move callbacks off ``twin``."""
        if self in twin._twins:
            twin._twins.remove(self)
            for f in self.connected_functions:
                twin.disconnect(f)

    def _set_twin(self, arg):
        """Twin this parameter to ``arg``, or untwin it when ``arg`` is None.

        When re-twinning to a different parameter, the previous twin is
        detached first so it no longer lists this parameter nor carries
        its connected functions.
        """
        if arg is None:
            if self.twin is not None:
                # Store the value of the twin in order to set the
                # value of the parameter when it is uncoupled
                twin_value = self.value
                self._detach_from_twin(self.twin)
                self.__twin = arg
                self.value = twin_value
        else:
            previous = self.twin
            if previous is not None and previous is not arg:
                self._detach_from_twin(previous)
            if self not in arg._twins:
                arg._twins.add(self)
                for f in self.connected_functions:
                    arg.connect(f)
            self.__twin = arg

        if self.component is not None:
            self.component._update_free_parameters()

    def _get_twin(self):
        return self.__twin

    twin = property(_get_twin, _set_twin)

    def _get_bmin(self):
        if self._number_of_elements == 1:
            return self._bounds[0]
        else:
            return self._bounds[0][0]

    def _set_bmin(self, arg):
        old_value = self.bmin
        if self._number_of_elements == 1:
            self._bounds = (arg, self.bmax)
        else:
            self._bounds = ((arg, self.bmax), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmin', old_value, arg)

    def _get_bmax(self):
        if self._number_of_elements == 1:
            return self._bounds[1]
        else:
            return self._bounds[0][1]

    def _set_bmax(self, arg):
        old_value = self.bmax
        if self._number_of_elements == 1:
            self._bounds = (self.bmin, arg)
        else:
            self._bounds = ((self.bmin, arg), ) * self._number_of_elements
        # Update the value to take into account the new bounds
        self.value = self.value
        self.trait_property_changed('bmax', old_value, arg)

    @property
    def _number_of_elements(self):
        return self.__number_of_elements

    @_number_of_elements.setter
    def _number_of_elements(self, arg):
        """Resize the parameter; bounds are kept and the value reset to 0."""
        # Do nothing if the number of arguments stays the same
        if self.__number_of_elements == arg:
            return
        # 1 is a valid size, as the message says; only reject < 1.
        if arg < 1:
            raise ValueError("Please provide an integer number equal "
                             "or greater to 1")
        # Read the bounds with the current layout before switching it.
        bmin, bmax = self.bmin, self.bmax
        if arg == 1:
            # One element uses a flat (bmin, bmax) pair, as _get_bmin and
            # _get_bmax expect.
            self._bounds = (bmin, bmax)
        else:
            self._bounds = ((bmin, bmax), ) * arg
        self.__number_of_elements = arg

        if arg == 1:
            self._Parameter__value = 0
        else:
            self._Parameter__value = (0, ) * arg
        if self.component is not None:
            self.component.update_number_parameters()

    @property
    def ext_bounded(self):
        return self.__ext_bounded

    @ext_bounded.setter
    def ext_bounded(self, arg):
        if arg is not self.__ext_bounded:
            self.__ext_bounded = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    @property
    def ext_force_positive(self):
        return self.__ext_force_positive

    @ext_force_positive.setter
    def ext_force_positive(self, arg):
        if arg is not self.__ext_force_positive:
            self.__ext_force_positive = arg
            # Update the value to take into account the new bounds
            self.value = self.value

    def store_current_value_in_array(self):
        """Store the value and std attributes.

        See also
        --------
        fetch, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        self.map['values'][indices] = self.value
        self.map['is_set'][indices] = True
        if self.std is not None:
            self.map['std'][indices] = self.std

    def fetch(self):
        """Fetch the stored value and std attributes.


        See Also
        --------
        store_current_value_in_array, assign_current_value_to_all

        """
        indices = self._axes_manager.indices[::-1]
        # If it is a single spectrum indices is ()
        if not indices:
            indices = (0, )
        if self.map['is_set'][indices]:
            self.value = self.map['values'][indices]
            self.std = self.map['std'][indices]

    def assign_current_value_to_all(self, mask=None):
        '''Assign the current value attribute to all the  indices

        Parameters
        ----------
        mask: {None, boolean numpy array}
            Set only the indices that are not masked i.e. where
            mask is False.

        See Also
        --------
        store_current_value_in_array, fetch

        '''
        if mask is None:
            mask = np.zeros(self.map.shape, dtype='bool')
        unmasked = np.logical_not(mask)
        self.map['values'][unmasked] = self.value
        self.map['is_set'][unmasked] = True

    def _create_array(self):
        """Create the map array to store the information in
        multidimensional datasets.

        The map is only re-created (and ``self.std`` reset) when its shape
        or dtype no longer match.
        """
        # ndarray.shape is a tuple; comparing it with a list is always
        # unequal, which rebuilt (and wiped) the map on every call.
        shape = tuple(self._axes_manager._navigation_shape_in_array)
        if not shape:
            shape = (1, )
        n = self._number_of_elements
        if n == 1:
            # ('name', type, 1) is deprecated by NumPy; plain scalar fields
            # are what it has always meant.
            dtype_ = np.dtype([('values', 'float'),
                               ('std', 'float'),
                               ('is_set', 'bool')])
        else:
            dtype_ = np.dtype([('values', 'float', n),
                               ('std', 'float', n),
                               ('is_set', 'bool')])
        if (self.map is None or self.map.shape != shape
                or self.map.dtype != dtype_):
            self.map = np.zeros(shape, dtype_)
            self.map['std'].fill(np.nan)
            # TODO: in the future this class should have access to
            # axes manager and should be able to fetch its own
            # values. Until then, the next line is necessary to avoid
            # errors when self.std is defined and the shape is different
            # from the newly defined arrays
            self.std = None

    def as_signal(self, field='values'):
        """Get a parameter map as a signal object.

        Please note that this method only works when the navigation
        dimension is greater than 0.

        Parameters
        ----------
        field : {'values', 'std', 'is_set'}

        Raises
        ------

        NavigationDimensionError : if the navigation dimension is 0

        """
        from hyperspy.signal import Signal
        if self._axes_manager.navigation_dimension == 0:
            raise NavigationDimensionError(0, '>0')

        s = Signal(data=self.map[field],
                   axes=self._axes_manager._get_navigation_axes_dicts())
        # Guard against parameters that do not belong to a component.
        if (self.component is not None
                and self.component.active_is_multidimensional):
            s.data[np.logical_not(self.component._active_array)] = np.nan
        if self.component is None:
            s.metadata.General.title = "%s parameter" % self.name
        else:
            s.metadata.General.title = (
                "%s parameter of %s component" %
                (self.name, self.component.name))
        for axis in s.axes_manager._axes:
            axis.navigate = False
        if self._number_of_elements > 1:
            s.axes_manager.append_axis(size=self._number_of_elements,
                                       name=self.name,
                                       navigate=True)
        return s

    def plot(self):
        self.as_signal().plot()

    def export(self, folder=None, name=None, format=None, save_std=False):
        '''Save the data to a file.

        All the arguments are optional.

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved.
             If `None` the current folder is used by default.
        name : str or None
            The name of the file. If `None` the Components name followed
             by the Parameter `name` attributes will be used by default
             (only the Parameter `name` when there is no component).
              If a file with the same name exists the name will be
              modified by appending a number to the file path.
        format : str or None
            Used as the file extension. If `None`,
            ``preferences.General.default_export_format`` is used.
        save_std : bool
            If True, also the standard deviation will be saved

        '''
        if format is None:
            format = preferences.General.default_export_format
        if name is None:
            if self.component is None:
                name = self.name
            else:
                name = self.component.name + '_' + self.name
        filename = incremental_filename(slugify(name) + '.' + format)
        if folder is not None:
            filename = os.path.join(folder, filename)
        self.as_signal().save(filename)
        if save_std is True:
            self.as_signal(field='std').save(append2pathname(filename, '_std'))

    def default_traits_view(self):
        # As mentioned above, the default editor for
        # value = t.Property(t.Either([t.CFloat(0), Array()]))
        # gives a ValueError. We therefore implement default_traits_view so
        # that configure/edit_traits will still work straight out of the box.
        # A whitelist controls which traits to include in this view.
        from traitsui.api import RangeEditor, View, Item
        whitelist = ['bmax', 'bmin', 'free', 'name', 'std', 'units', 'value']
        editable_traits = [
            trait for trait in self.editable_traits() if trait in whitelist
        ]
        if 'value' in editable_traits:
            i = editable_traits.index('value')
            v = editable_traits.pop(i)
            editable_traits.insert(
                i,
                Item(v, editor=RangeEditor(low_name='bmin', high_name='bmax')))
        view = View(editable_traits, buttons=['OK', 'Cancel'])
        return view
Example #6
0
class Parameter(t.HasTraits):
    """Model parameter

    Attributes
    ----------
    value : float or array
        The value of the parameter for the current location. The value
        for other locations is stored in map.
    bmin, bmax: float
        Lower and upper bounds of the parameter value.
    twin : {None, Parameter}
        If it is not None, the value of the current parameter is
        a function of the given Parameter. The function is by default
        the identity function, but it can be defined by twin_function
    twin_function : function
        Function that, if selt.twin is not None, takes self.twin.value
        as its only argument and returns a float or array that is
        returned when getting Parameter.value
    twin_inverse_function : function
        The inverse of twin_function. If it is None then it is not
        possible to set the value of the parameter twin by setting
        the value of the current parameter.
    ext_force_positive : bool
        If True, the parameter value is set to be the absolute value
        of the input value i.e. if we set Parameter.value = -3, the
        value stored is 3 instead. This is useful to bound a value
        to be positive in an optimization without actually using an
        optimizer that supports bounding.
    ext_bounded : bool
        Similar to ext_force_positive, but in this case the bounds are
        defined by bmin and bmax. It is a better idea to use
        an optimizer that supports bounding though.

    Methods
    -------
    as_signal(field = 'values')
        Get a parameter map as a signal object
    plot()
        Plots the value of the Parameter at all locations.
    export(folder=None, name=None, format=None, save_std=False)
        Saves the value of the parameter map to the specified format
    connect, disconnect(function)
        Call the functions connected when the value attribute changes.

    """
    __number_of_elements = 1
    __value = 0
    __free = True
    _bounds = (None, None)
    __twin = None
    _axes_manager = None
    __ext_bounded = False
    __ext_force_positive = False

    # traitsui bugs out trying to make an editor for this, so always specify!
    # (it bugs out, because both editor shares the object, and Array editors
    # don't like non-sequence objects). TextEditor() works well, so does
    # RangeEditor() as it works with bmin/bmax.
    value = t.Property(t.Either([t.CFloat(0), Array()]))

    units = t.Str('')
    free = t.Property(t.CBool(True))

    bmin = t.Property(NoneFloat(), label="Lower bounds")
    bmax = t.Property(NoneFloat(), label="Upper bounds")

    def __init__(self):
        """Initialise an empty, untwinned parameter."""
        self._twins = set()

        # Fired from _set_value / _on_twin_update when the value changes.
        self.events = Events()
        self.events.value_changed = Event("""
            Event that triggers when the `Parameter.value` changes.

            The event triggers after the internal state of the `Parameter` has
            been updated.

            Arguments
            ---------
            obj : Parameter
                The `Parameter` that the event belongs to
            value : {float | array}
                The new value of the parameter
            """, arguments=["obj", 'value'])

        # Identity mappings by default (two distinct function objects).
        self.twin_function = lambda x: x
        self.twin_inverse_function = lambda x: x

        # Placeholder attributes, all empty at construction time.
        self.std = self.component = self.grad = None
        self.name = ''
        self.units = ''
        self.map = self.model = None

        # Attributes handled when (de)serialising to a dictionary. Plain
        # attributes map to None; the tuple entries presumably carry a
        # handling flag understood by export_to_dictionary — TODO confirm.
        plain_attributes = ('_id_name', 'value', 'std', 'free', 'units',
                            'map', '_bounds', 'ext_bounded', 'name',
                            'ext_force_positive')
        self._whitelist = dict.fromkeys(plain_attributes)
        self._whitelist['self'] = ('id', None)
        self._whitelist['twin_function'] = ('fn', None)
        self._whitelist['twin_inverse_function'] = ('fn', None)
        self._slicing_whitelist = {'map': 'inav'}

    def _load_dictionary(self, dictionary):
        """Load data from dictionary

        Parameters
        ----------
        dictionary : dict
            A dictionary containing at least the following items:
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.load_from_dictionary`
            * any field from _whitelist.keys() *

        Returns
        -------
        id_value : int
            the ID value of the original parameter, to be later used for setting
            up the correct twins

        Raises
        ------
        ValueError
            If ``dictionary['_id_name']`` differs from ``self._id_name``.

        """
        if dictionary['_id_name'] != self._id_name:
            # Implicitly concatenated literals keep source indentation out of
            # the message (a backslash-continued literal embeds it).
            raise ValueError(
                "_id_name of parameter and dictionary do not match,\n"
                "parameter._id_name = %s\n"
                "dictionary['_id_name'] = %s"
                % (self._id_name, dictionary['_id_name']))
        load_from_dictionary(self, dictionary)
        return dictionary['self']

    def __repr__(self):
        """Return ``<Parameter NAME>``, adding ``of COMPONENT`` when owned."""
        description = 'Parameter %s' % self.name
        owner = self.component
        if owner is not None:
            description += ' of %s' % owner._get_short_description()
        return '<%s>' % description

    def __len__(self):
        """Return the number of elements of the parameter."""
        n_elements = self._number_of_elements
        return n_elements

    def _get_value(self):
        """Return the stored value, or the twin's value mapped through
        ``twin_function`` when the parameter is twinned."""
        twin = self.twin
        if twin is not None:
            return self.twin_function(twin.value)
        return self.__value

    def _set_value(self, value):
        """Set the parameter value, honouring twins and external bounds.

        Parameters
        ----------
        value : float or sequence
            New value. A sequence must have exactly ``len(self)`` elements
            and is stored as a tuple.

        Raises
        ------
        ValueError
            If the length of ``value`` does not match the number of elements
            of the parameter.

        Notes
        -----
        When twinned, the value is forwarded to the twin through
        ``twin_inverse_function`` (and ignored if that is None). Unless
        ``ext_bounded`` is False, the value is clipped to ``[bmin, bmax]``,
        after taking its absolute value if ``ext_force_positive`` is True.
        ``events.value_changed`` is triggered when the stored value changes.
        """
        try:
            # Use try/except instead of hasattr("__len__") because a numpy
            # memmap has a __len__ wrapper even for numbers that raises a
            # TypeError when calling. See issue #349.
            if len(value) != self._number_of_elements:
                raise ValueError("The length of the parameter must be %s"
                                 % self._number_of_elements)
            elif not isinstance(value, tuple):
                value = tuple(value)
        except TypeError:
            if self._number_of_elements != 1:
                raise ValueError("The length of the parameter must be %s"
                                 % self._number_of_elements)
        old_value = self.__value

        if self.twin is not None:
            if self.twin_inverse_function is not None:
                self.twin.value = self.twin_inverse_function(value)
            return

        if self.ext_bounded is False:
            self.__value = value
        else:
            if self.ext_force_positive is True:
                value = np.abs(value)
            if self._number_of_elements == 1:
                if self.bmin is not None and value <= self.bmin:
                    self.__value = self.bmin
                elif self.bmax is not None and value >= self.bmax:
                    self.__value = self.bmax
                else:
                    self.__value = value
            else:
                bmin = self.bmin if self.bmin is not None else -np.inf
                # Test bmax itself; checking bmin here ignored a set bmax
                # whenever bmin was None.
                bmax = self.bmax if self.bmax is not None else np.inf
                self.__value = np.clip(value, bmin, bmax)

        if (self._number_of_elements != 1
                and not isinstance(self.__value, tuple)):
            self.__value = tuple(self.__value)
        if old_value != self.__value:
            self.events.value_changed.trigger(value=self.__value, obj=self)
        self.trait_property_changed('value', old_value, self.__value)

    # Fix the parameter when coupled
    def _get_free(self):
        """Return the free flag; a twinned parameter is never free."""
        return self.__free if self.twin is None else False

    def _set_free(self, arg):
        """Store the free flag, refresh the component and notify traits."""
        previous, self.__free = self.__free, arg
        owner = self.component
        if owner is not None:
            owner._update_free_parameters()
        self.trait_property_changed('free', previous, self.__free)

    def _on_twin_update(self, value, twin=None):
        """Re-emit a twin's value change as this parameter's value_changed.

        When ``twin`` exposes ``events.value_changed``, this callback is
        suppressed on it while triggering, to avoid re-entrant calls.
        """
        can_suppress = (twin is not None and hasattr(twin, 'events')
                        and hasattr(twin.events, 'value_changed'))
        if not can_suppress:
            self.events.value_changed.trigger(value=value, obj=self)
            return
        with twin.events.value_changed.suppress_callback(
                self._on_twin_update):
            self.events.value_changed.trigger(value=value, obj=self)

    def _set_twin(self, arg):
        """Twin this parameter to ``arg``, or untwin it when ``arg`` is None.

        When untwinning, the parameter keeps the value it had through its
        twin. When re-twinning to a different parameter, the previous twin
        is detached first, so it no longer lists this parameter in
        ``_twins`` nor forwards its ``value_changed`` events to it.
        """
        if arg is None:
            if self.twin is not None:
                # Store the value of the twin in order to set the
                # value of the parameter when it is uncoupled
                twin_value = self.value
                self._detach_from_twin(self.twin)
                self.__twin = arg
                self.value = twin_value
        else:
            previous = self.twin
            if previous is not None and previous is not arg:
                self._detach_from_twin(previous)
            if self not in arg._twins:
                arg._twins.add(self)
                arg.events.value_changed.connect(self._on_twin_update,
                                                 ["value"])
            self.__twin = arg

        if self.component is not None:
            self.component._update_free_parameters()

    def _detach_from_twin(self, twin):
        """Remove self from ``twin._twins`` and stop listening to ``twin``."""
        if self in twin._twins:
            twin._twins.remove(self)
            twin.events.value_changed.disconnect(self._on_twin_update)

    def _get_twin(self):
        """Return the Parameter this one is twinned to, or None."""
        current_twin = self.__twin
        return current_twin

    twin = property(_get_twin, _set_twin)

    def _get_bmin(self):
        """Return the lower bound.

        With several elements, ``_bounds`` holds one (bmin, bmax) pair per
        element and the first pair's lower bound is returned.
        """
        first = self._bounds[0]
        if self._number_of_elements != 1:
            return first[0]
        return first

    def _set_bmin(self, arg):
        """Set the lower bound for every element and re-apply it."""
        previous = self.bmin
        pair = (arg, self.bmax)
        count = self._number_of_elements
        self._bounds = pair if count == 1 else (pair, ) * count
        # Re-assigning the value lets _set_value apply the new bounds.
        self.value = self.value
        self.trait_property_changed('bmin', previous, arg)

    def _get_bmax(self):
        """Return the upper bound.

        With several elements, ``_bounds`` holds one (bmin, bmax) pair per
        element and the first pair's upper bound is returned.
        """
        if self._number_of_elements != 1:
            return self._bounds[0][1]
        return self._bounds[1]

    def _set_bmax(self, arg):
        """Set the upper bound for every element and re-apply it."""
        previous = self.bmax
        pair = (self.bmin, arg)
        count = self._number_of_elements
        self._bounds = pair if count == 1 else (pair, ) * count
        # Re-assigning the value lets _set_value apply the new bounds.
        self.value = self.value
        self.trait_property_changed('bmax', previous, arg)

    @property
    def _number_of_elements(self):
        """Number of elements of the parameter value (1 for scalars)."""
        count = self.__number_of_elements
        return count

    @_number_of_elements.setter
    def _number_of_elements(self, arg):
        """Resize the parameter to ``arg`` elements.

        The current bounds are kept (replicated per element), the value is
        reset to zero(s) and the owning component, if any, is told to
        update its parameter count.

        Raises
        ------
        ValueError
            If ``arg`` is smaller than 1.
        """
        # Do nothing if the number of arguments stays the same
        if self.__number_of_elements == arg:
            return
        # 1 is a valid size, as the message states; only reject < 1.
        if arg < 1:
            raise ValueError("Please provide an integer number equal "
                             "or greater to 1")
        # Read the bounds with the current layout before switching it.
        bmin, bmax = self.bmin, self.bmax
        if arg == 1:
            # A single element uses a flat (bmin, bmax) pair, which is what
            # _get_bmin/_get_bmax read when there is one element.
            self._bounds = (bmin, bmax)
        else:
            self._bounds = ((bmin, bmax), ) * arg
        self.__number_of_elements = arg

        if arg == 1:
            self._Parameter__value = 0
        else:
            self._Parameter__value = (0, ) * arg
        if self.component is not None:
            self.component.update_number_parameters()

    @property
    def ext_bounded(self):
        """Unless this is False, set values are clipped to [bmin, bmax]."""
        flag = self.__ext_bounded
        return flag

    @ext_bounded.setter
    def ext_bounded(self, arg):
        if arg is self.__ext_bounded:
            return
        self.__ext_bounded = arg
        # Re-assign the value so the bounds are (un)applied to it.
        self.value = self.value

    @property
    def ext_force_positive(self):
        """If True (and ext_bounded is enabled), set values are made
        non-negative with ``np.abs``."""
        flag = self.__ext_force_positive
        return flag

    @ext_force_positive.setter
    def ext_force_positive(self, arg):
        if arg is self.__ext_force_positive:
            return
        self.__ext_force_positive = arg
        # Re-assign the value so the new setting is applied to it.
        self.value = self.value

    def store_current_value_in_array(self):
        """Store the value and std attributes.

        The current navigation position (axes manager indices, reversed)
        selects the entry of ``self.map`` that is written; ``std`` is only
        stored when it is not None.

        See also
        --------
        fetch, assign_current_value_to_all

        """
        # An empty index tuple means a single spectrum: use entry 0.
        position = self._axes_manager.indices[::-1] or (0, )
        self.map['values'][position] = self.value
        self.map['is_set'][position] = True
        if self.std is None:
            return
        self.map['std'][position] = self.std

    def fetch(self):
        """Fetch the stored value and std attributes.

        Nothing is restored unless ``map['is_set']`` is true at the current
        navigation position.

        See Also
        --------
        store_current_value_in_array, assign_current_value_to_all

        """
        # An empty index tuple means a single spectrum: use entry 0.
        position = self._axes_manager.indices[::-1] or (0, )
        if not self.map['is_set'][position]:
            return
        self.value = self.map['values'][position]
        self.std = self.map['std'][position]

    def assign_current_value_to_all(self, mask=None):
        """Assign the current value attribute to all the  indices

        Parameters
        ----------
        mask: {None, boolean numpy array}
            Set only the indices that are not masked i.e. where
            mask is False. If None, every index is set.

        See Also
        --------
        store_current_value_in_array, fetch

        """
        if mask is None:
            mask = np.zeros(self.map.shape, dtype='bool')
        # Element-wise comparison selecting the unmasked positions.
        unmasked = mask == False  # noqa: E712
        self.map['values'][unmasked] = self.value
        self.map['is_set'][unmasked] = True

    def _create_array(self):
        """Create the map array to store the information in
        multidimensional datasets.

        ``self.map`` is a structured array with one entry per navigation
        position (a single entry without navigation axes) and the fields
        ``values``, ``std`` (filled with NaN) and ``is_set``. It is only
        re-created, and ``self.std`` reset to None, when its shape or dtype
        no longer match.
        """
        # ndarray.shape is a tuple; comparing it with a list is always
        # unequal, which rebuilt (and wiped) the map on every call.
        shape = tuple(self._axes_manager._navigation_shape_in_array)
        if not shape:
            shape = (1, )
        n = self._number_of_elements
        if n == 1:
            # ('name', type, 1) is deprecated by NumPy; plain scalar fields
            # are what it has always meant.
            dtype_ = np.dtype([('values', 'float'),
                               ('std', 'float'),
                               ('is_set', 'bool')])
        else:
            dtype_ = np.dtype([('values', 'float', n),
                               ('std', 'float', n),
                               ('is_set', 'bool')])
        if (self.map is None or self.map.shape != shape
                or self.map.dtype != dtype_):
            self.map = np.zeros(shape, dtype_)
            self.map['std'].fill(np.nan)
            # TODO: in the future this class should have access to
            # axes manager and should be able to fetch its own
            # values. Until then, the next line is necessary to avoid
            # errors when self.std is defined and the shape is different
            # from the newly defined arrays
            self.std = None

    def as_signal(self, field='values'):
        """Get a parameter map as a signal object.

        Parameters
        ----------
        field : {'values', 'std', 'is_set'}
            Field of ``self.map`` to wrap.

        Returns
        -------
        signal
            A signal built from ``BaseSignal`` (its class is then refined by
            ``_assign_subclass``) whose axes are the navigation axes of the
            model, all marked as signal axes. Multi-element parameters get
            an extra navigation axis named after the parameter. Positions
            where a multidimensional component is inactive are set to NaN.
            For ``field='values'``, if any std is not NaN the squared std is
            attached as ``metadata.Signal.Noise_properties.variance``.
        """
        from hyperspy.signal import BaseSignal

        s = BaseSignal(data=self.map[field],
                       axes=self._axes_manager._get_navigation_axes_dicts())
        if self.component is not None and \
                self.component.active_is_multidimensional:
            s.data[np.logical_not(self.component._active_array)] = np.nan

        if self.component is None:
            title = "%s parameter" % self.name
        else:
            title = "%s parameter of %s component" % (self.name,
                                                      self.component.name)
        s.metadata.General.title = title
        for axis in s.axes_manager._axes:
            axis.navigate = False
        if self._number_of_elements > 1:
            s.axes_manager._append_axis(size=self._number_of_elements,
                                        name=self.name,
                                        navigate=True)
        s._assign_subclass()
        if field == "values":
            # Add the variance if available
            std = self.as_signal(field="std")
            if not np.isnan(std.data).all():
                std.data = std.data**2
                std.metadata.General.title = "Variance"
                s.metadata.set_item("Signal.Noise_properties.variance", std)
        return s

    def plot(self, **kwargs):
        """Plot the parameter values as a signal.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to the ``plot`` method of the signal
            returned by :meth:`as_signal`.

        Example
        -------
        >>> parameter.plot()

        Set the minimum and maximum displayed values

        >>> parameter.plot(vmin=0, vmax=1)
        """
        signal = self.as_signal()
        signal.plot(**kwargs)

    def export(self, folder=None, name=None, format=None, save_std=False):
        """Save the parameter values (and optionally their std) to a file.

        All the arguments are optional.

        Parameters
        ----------
        folder : str or None
            The path to the folder where the file will be saved.
            If `None` the current folder is used by default.
        name : str or None
            The name of the file. If `None`, the component name followed by
            the parameter `name` attribute is used (just the parameter name
            when the parameter has no component). If a file with the same
            name already exists in the destination folder, the name is
            modified by appending a number to the file path.
        format : str or None
            Used as the file extension. If `None`,
            ``preferences.General.default_export_format`` is used.
        save_std : bool
            If True, also the standard deviation will be saved, to the path
            given by ``append2pathname(filename, '_std')``.
        """
        if format is None:
            format = preferences.General.default_export_format
        if name is None:
            if self.component is None:
                name = self.name
            else:
                name = self.component.name + '_' + self.name
        filename = slugify(name) + '.' + format
        if folder is not None:
            filename = os.path.join(folder, filename)
        # Resolve name collisions against the destination folder, not the
        # current working directory.
        filename = incremental_filename(filename)
        self.as_signal().save(filename)
        if save_std is True:
            self.as_signal(field='std').save(append2pathname(filename, '_std'))

    def as_dictionary(self, fullcopy=True):
        """Return the parameter as a dictionary.

        All attributes listed in ``self._whitelist.keys()`` are saved. For
        more information see
        :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`.

        Parameters
        ----------
        fullcopy : bool, optional (default True)
            Copies of objects are stored, not references. If any found,
            functions will be pickled and signals converted to dictionaries

        Returns
        -------
        dic : dictionary with the following keys:
            _twins : list
                a list of ids (``id()``) of the twins of the parameter
            _id_name : string
                _id_name of the original parameter, used to create the
                dictionary. Has to match with the self._id_name
            _whitelist : dictionary
                a dictionary, which keys are used as keywords to match with the
                parameter attributes.  For more information see
                :meth:`hyperspy.misc.export_dictionary.export_to_dictionary`
            * any field from _whitelist.keys() *

            All keys other than ``_twins`` are filled in by
            ``export_to_dictionary``.
        """
        twin_ids = [id(twin) for twin in self._twins]
        dic = dict(_twins=twin_ids)
        export_to_dictionary(self, self._whitelist, dic, fullcopy)
        return dic

    def default_traits_view(self):
        """Build the TraitsUI view used by ``configure_traits``/``edit_traits``.

        The automatically generated editor for ``value``
        (``t.Property(t.Either([t.CFloat(0), Array()]))``) gives a
        ValueError, so this custom view is provided instead. Only a
        whitelisted subset of the editable traits is shown, in the order
        returned by ``editable_traits()``. ``value`` is edited with a
        RangeEditor bounded by ``bmin`` and ``bmax``.
        """
        from traitsui.api import RangeEditor, View, Item
        shown = ('bmax', 'bmin', 'free', 'name', 'std', 'units', 'value')
        items = []
        for trait_name in self.editable_traits():
            if trait_name not in shown:
                continue
            if trait_name == 'value':
                # Keep the position of 'value' but use a bounded editor.
                range_editor = RangeEditor(low_name='bmin', high_name='bmax')
                items.append(Item(trait_name, editor=range_editor))
            else:
                items.append(trait_name)
        return View(items, buttons=['OK', 'Cancel'])

    def _interactive_slider_bounds(self, index=None):
        """Guesstimates the bounds for the slider. They will probably have to
        be changed later by the user.
        """
        fraction = 10.
        _min, _max, step = None, None, None
        value = self.value if index is None else self.value[index]
        if self.bmin is not None:
            _min = self.bmin
        if self.bmax is not None:
            _max = self.bmax
        if _max is None and _min is not None:
            _max = value + fraction * (value - _min)
        if _min is None and _max is not None:
            _min = value - fraction * (_max - value)
        if _min is None and _max is None:
            if self is self.component._position:
                axis = self._axes_manager.signal_axes[-1]
                _min = axis.axis.min()
                _max = axis.axis.max()
                step = np.abs(axis.scale)
            else:
                _max = value + np.abs(value * fraction)
                _min = value - np.abs(value * fraction)
        if step is None:
            step = (_max - _min) * 0.001
        return {'min': _min, 'max': _max, 'step': step}

    def _interactive_update(self, value=None, index=None):
        """Callback function for the widgets, to update the value
        """
        if value is not None:
            if index is None:
                self.value = value['new']
            else:
                self.value = self.value[:index] + (value['new'],) +\
                    self.value[index + 1:]

    def notebook_interaction(self, display=True):
        """Create interactive notebook widgets for the parameter, if available.

        Requires `ipywidgets` to be installed.

        Parameters
        ----------
        display : bool
            If True (default), attempts to display the parameter widget.
            Otherwise returns the formatted widget object.

        Returns
        -------
        widget or None
            The widget container when ``display`` is False, otherwise None.

        Raises
        ------
        traitlets.TraitError
            Re-raised when widget creation fails and ``display`` is False.
        """
        from ipywidgets import VBox
        from traitlets import TraitError as TraitletError
        from IPython.display import display as ip_display
        try:
            if self._number_of_elements == 1:
                container = self._create_notebook_widget()
            else:
                # One row of widgets per element, stacked vertically.
                children = [
                    self._create_notebook_widget(index=i)
                    for i in range(self._number_of_elements)
                ]
                container = VBox(children)
            if not display:
                return container
            ip_display(container)
        except TraitletError:
            # When only displaying, a widget failure is reported as info
            # (e.g. not running in a notebook); otherwise it propagates.
            if display:
                _logger.info('This function is only available when running in'
                             ' a notebook')
            else:
                raise

    def _create_notebook_widget(self, index=None):
        """Build an ``HBox`` of (min text box, value slider, max text box).

        Editing the min/max boxes moves the slider bounds (only when the new
        bound keeps min < max) and rescales the slider step to 1/1000 of
        the range. Moving the slider updates the parameter through
        :meth:`_interactive_update`.

        Parameters
        ----------
        index : int or None
            Element of a multi-valued parameter controlled by the slider;
            its description gets an ``_<index>`` suffix.
        """
        from ipywidgets import (FloatSlider, FloatText, Layout, HBox)

        bounds = self._interactive_slider_bounds(index=index)

        def bound_box(key):
            # Text box editing one slider bound; described by its key.
            return FloatText(
                value=bounds[key],
                description=key,
                layout=Layout(flex='0 1 auto', width='auto'),
            )

        min_box = bound_box('min')
        max_box = bound_box('max')

        if index is None:
            start_value = self.value
            label = self.name
        else:
            start_value = self.value[index]
            label = self.name + '_{}'.format(index)

        slider = FloatSlider(value=start_value,
                             min=min_box.value,
                             max=max_box.value,
                             step=bounds['step'],
                             description=label,
                             layout=Layout(flex='1 1 auto', width='auto'))

        def rescale_step():
            slider.step = np.abs(slider.max - slider.min) * 0.001

        def on_min_change(change):
            new_min = change['new']
            if new_min < slider.max:
                slider.min = new_min
                rescale_step()

        def on_max_change(change):
            new_max = change['new']
            if new_max > slider.min:
                slider.max = new_max
                rescale_step()

        min_box.observe(on_min_change, names='value')
        max_box.observe(on_max_change, names='value')

        on_slide = functools.partial(self._interactive_update, index=index)
        slider.observe(on_slide, names='value')
        return HBox((min_box, slider, max_box))
Example #7
0
class ThetaScatterPlot(ModelView, PyannoPlotContainer):
    """Defines a view of the annotator accuracy parameters, theta.

    The view consists in a Chaco plot that displays the theta parameter for
    each annotator, and samples from the posterior distribution over theta
    with a combination of a scatter plot and a candle plot.

    For annotator ``idx`` the plot shows the current ``model.theta[idx]`` as
    a '+' marker and, when ``theta_samples_valid`` is True, the samples in
    column ``idx`` of ``theta_samples`` as a jittered scatter plot next to a
    candle plot summarizing them (5th/95th sample percentiles, mean +/- one
    standard deviation, and mean).
    """

    #### Traits definition ####################################################

    # True when `theta_samples` contains samples that should be displayed
    theta_samples_valid = Bool(False)
    # Samples of theta; column j holds the samples for annotator j
    theta_samples = Array(dtype=float, shape=(None, None))

    # return value for "Copy" action on plot
    data = DictStrAny

    def _data_default(self):
        """Initial value of `data`: the current theta and no samples."""
        return {'theta': self.model.theta, 'theta_samples': None}

    @on_trait_change('redraw,theta_samples,theta_samples_valid')
    def _update_data(self):
        """Refresh `data` with the current theta and, if valid, the samples.

        The dictionary is updated in place; its 'theta_samples' entry is
        None when `theta_samples_valid` is False.
        """
        if self.theta_samples_valid:
            theta_samples = self.theta_samples
        else:
            theta_samples = None

        self.data['theta'] = self.model.theta
        self.data['theta_samples'] = theta_samples

    #### plot-related traits
    title = Str('Accuracy (theta)')

    # Chaco data source shared by all the renderers of the plot
    theta_plot_data = Instance(ArrayPlotData)
    # Container returned by `_theta_plot_default` (wraps the actual Plot)
    theta_plot = Any

    # Fire to refresh `data`, the plot data and the axis range
    redraw = Event


    ### Plot definition #######################################################

    def _compute_range2d(self):
        """Return the 2D data range of the theta plot.

        The index (horizontal) range spans 0 to n_annotators + 1. The value
        (vertical) range goes from the smallest of 0.6, min(theta) - 0.05
        and, if valid, min(theta_samples) - 0.05, up to 1.
        """
        low = min(0.6, self.model.theta.min()-0.05)
        if self.theta_samples_valid:
            low = min(low, self.theta_samples.min()-0.05)
        range2d = DataRange2D(low=(0., low),
                              high=(self.model.theta.shape[0]+1, 1.))
        return range2d


    @on_trait_change('redraw', post_init=True)
    def _update_range2d(self):
        """Recompute the axis range whenever `redraw` fires after init.

        NOTE(review): `theta_plot` holds the VPlotContainer returned by
        `_theta_plot_default`, not the inner Plot, so this assigns `range2d`
        on the container -- confirm the displayed range actually changes.
        """
        self.theta_plot.range2d = self._compute_range2d()


    def _theta_plot_default(self):
        """Create plot of theta parameters.

        Returns a VPlotContainer holding a Chaco Plot with, for each
        annotator, a candle plot and a scatter plot of the samples plus a
        '+' marker at the current theta value.
        """

        # We plot both the thetas and the samples from the posterior; if the
        # latter are not defined, the corresponding ArrayPlotData names
        # should be set to an empty list, so that they are not displayed
        theta = self.model.theta
        theta_len = theta.shape[0]

        # create the plot data
        if not self.theta_plot_data:
            self.theta_plot_data = ArrayPlotData()
            # NOTE(review): assigning `theta_plot_data` presumably already
            # triggers `_update_plot_data` via its @on_trait_change listener,
            # so this explicit call may run it twice (re-drawing the random
            # jitter) -- harmless but redundant; confirm.
            self._update_plot_data()

        # create the plot
        theta_plot = Plot(self.theta_plot_data)

        # Data series names are built by `_w_idx(name, idx)` (presumably
        # the name with the annotator index appended -- confirm).
        for idx in range(theta_len):
            # candle plot summarizing samples over the posterior
            theta_plot.candle_plot((_w_idx('index', idx),
                                    _w_idx('min', idx),
                                    _w_idx('barmin', idx),
                                    _w_idx('avg', idx),
                                    _w_idx('barmax', idx),
                                    _w_idx('max', idx)),
                                    color = get_annotator_color(idx),
                                    bar_line_color = "black",
                                    stem_color = "blue",
                                    center_color = "red",
                                    center_width = 2)

            # plot of raw samples
            # ('ysamples' = jittered horizontal position, used as the index
            # series; 'xsamples' = sampled theta values)
            theta_plot.plot((_w_idx('ysamples', idx),
                             _w_idx('xsamples', idx)),
                            type='scatter',
                            color='black',
                            marker='dot',
                            line_width=0.5,
                            marker_size=1)

            # plot current parameters
            # ('y' = horizontal position idx + 1.2, 'x' = current theta)
            theta_plot.plot((_w_idx('y', idx), _w_idx('x', idx)),
                            type='scatter',
                            color=get_annotator_color(idx),
                            marker='plus',
                            marker_size=8,
                            line_width=2)

        # adjust axis bounds
        theta_plot.range2d = self._compute_range2d()

        # remove horizontal grid and axis
        theta_plot.underlays = [theta_plot.x_grid, theta_plot.y_axis]

        # create new horizontal axis
        # (annotator labels "0".."n-1" placed at positions 1..n)
        label_list = [str(i) for i in range(theta_len)]

        label_axis = LabelAxis(
            theta_plot,
            orientation = 'bottom',
            positions = list(range(1, theta_len+1)),
            labels = label_list,
            label_rotation = 0
        )
        # use a FixedScale tick generator with a resolution of 1
        label_axis.tick_generator = ScalesTickGenerator(scale=FixedScale(1.))

        theta_plot.index_axis = label_axis
        theta_plot.underlays.append(label_axis)
        theta_plot.padding = 25
        theta_plot.padding_left = 40
        theta_plot.aspect_ratio = 1.0

        container = VPlotContainer()
        container.add(theta_plot)
        container.bgcolor = 0xFFFFFF

        self.decorate_plot(container, theta)
        self._set_title(theta_plot)

        return container


    ### Handle plot data ######################################################

    def _samples_names_and_values(self, idx):
        """Return a list of names and values for the samples PlotData.

        Each element is a ``(_w_idx(key, idx), value)`` pair for annotator
        `idx`. All values are empty lists when the samples are not valid.
        Otherwise they describe the sorted samples in column `idx` of
        `theta_samples`: the samples themselves, their jittered horizontal
        positions (idx + 1.2 +/- 0.05), the 5th/95th percentile samples,
        mean -/+ one std, the mean, and the candle position idx + 0.8.
        """

        # In the following code, we rely on lazy evaluation of the
        # X if CONDITION else Y statements to return a default value if the
        # theta samples are not currently defined, or the real value if they
        # are.

        invalid = not self.theta_samples_valid
        samples = [] if invalid else np.sort(self.theta_samples[:,idx])
        nsamples = None if invalid else samples.shape[0]
        # sample percentiles taken by position in the sorted samples
        perc5 = None if invalid else samples[int(nsamples*0.05)]
        perc95 = None if invalid else samples[int(nsamples*0.95)]

        data_dict = {
            'xsamples':
                [] if invalid else samples,
            'ysamples':
                [] if invalid else (
                    np.random.random(size=(nsamples,))*0.1-0.05 + idx + 1.2
                    ),
            'min':
                [] if invalid else [perc5],
            'max':
                [] if invalid else [perc95],
            'barmin':
                [] if invalid else [samples.mean() - samples.std()],
            'barmax':
                [] if invalid else [samples.mean() + samples.std()],
            'avg':
                [] if invalid else [samples.mean()],
            'index':
                [] if invalid else [idx + 0.8]
        }

        name_value = [(_w_idx(name, idx), value)
                      for name, value in list(data_dict.items())]
        return name_value

    @on_trait_change('theta_plot_data,theta_samples_valid,redraw')
    def _update_plot_data(self):
        """Updates PlotData on changes.

        For each annotator, sets the current theta ('x<idx>') at horizontal
        position idx + 1.2 ('y<idx>'), plus all the sample series returned
        by `_samples_names_and_values`. Does nothing while
        `theta_plot_data` is None.
        """
        theta = self.model.theta

        plot_data = self.theta_plot_data

        if plot_data is not None:
            for idx, th in enumerate(theta):
                # NOTE(review): these names presumably match
                # _w_idx('x', idx) / _w_idx('y', idx) used by the plot.
                plot_data.set_data('x%d' % idx, [th])
                plot_data.set_data('y%d' % idx, [idx+1.2])

                for name_value in self._samples_names_and_values(idx):
                    name, value = name_value
                    plot_data.set_data(name, value)


    #### View definition #####################################################

    # Resizable item showing the plot (presumably used by view definitions
    # outside this class -- confirm)
    resizable_plot_item = Item(
        'theta_plot',
        editor=ComponentEditor(),
        resizable=True,
        show_label=False,
        width=600,
        height=400
        )

    # Non-resizable plot item whose height depends on the display size
    traits_plot_item = Instance(Item)

    def _traits_plot_item_default(self):
        """Non-resizable plot item, shorter on small displays.

        NOTE(review): the negative height presumably requests a fixed
        size in TraitsUI -- confirm.
        """
        height = -220 if is_display_small() else -280
        return Item('theta_plot',
                    editor=ComponentEditor(),
                    resizable=False,
                    show_label=False,
                    height=height
                    )
Example #8
0
class ModelA(AbstractModel):
    """Implementation of Model A from (Rzhetsky et al., 2009).

    The model defines a probability distribution over data annotations
    in which each item is annotated by three users. The distributions is
    described according to a three-steps generative model:

        1. First, the model independently generates correctness values for the
        triplet of annotators (e.g., CCI where C=correct, I=incorrect)

        2. Second, the model generates an agreement pattern compatible with
        the correctness values (e.g., CII is compatible with the agreement
        patterns 'abb' and 'abc', where different letters correspond to
        different annotations

        3. Finally, the model generates actual observations compatible with
        the agreement patterns

    The model has two main sets of parameters:

        - theta[j] is the probability that annotator j is correct

        - omega[k] is the probability of observing an annotation of class `k`
          over all items and annotators

    At the moment the implementation of the model assumes 1) a total of 8
    annotators, and 2) each item is annotated by exactly 3 annotators.

    See the documentation for a more detailed description of the model.

    **Reference**

    * Rzhetsky A., Shatkay, H., and Wilbur, W.J. (2009). "How to get the most
      from your curation effort", PLoS Computational Biology, 5(5).
    """

    ######## Model traits

    # number of label classes
    nclasses = Int

    # number of annotators
    nannotators = Int(8)

    # number of annotators rating each item in the loop design
    nannotators_per_item = Int(3)

    #### Model parameters

    # theta[j] is the probability that annotator j is correct
    theta = Array(dtype=float, shape=(None, ))

    # omega[k] is the probability of observing label class k
    omega = Array(dtype=float, shape=(None, ))

    def __init__(self, nclasses, theta, omega, **traits):
        """Create an instance of ModelA.

        Arguments
        ---------
        nclasses : int
            Number of possible annotation classes

        theta : ndarray, shape = (n_annotators, )
            theta[j] is the probability of annotator j being correct

        omega : ndarray, shape = (n_classes, )
            omega[k] is the probability of observing a label of class k

        **traits
            Remaining trait values, forwarded to the base-class constructor.
        """
        # Model parameters are set first (in this order), then the base
        # class receives the remaining traits.
        self.nclasses, self.theta, self.omega = nclasses, theta, omega
        super(ModelA, self).__init__(**traits)

    ##### Model and data generation methods ###################################

    @staticmethod
    def create_initial_state(nclasses, theta=None, omega=None):
        r"""Factory method to create a new model.

        It is often more convenient to use this factory method over the
        constructor, as one does not need to specify the initial model
        parameters.

        If not specified, the parameters theta are drawn from a uniform
        distribution between 0.6 and 0.95 . The parameters omega are drawn
        from a Dirichlet distribution with parameters 2.0 :

        :math:`\theta_j \sim \mathrm{Uniform}(0.6, 0.95)`

        :math:`\omega_k \sim \mathrm{Dirichlet}(2.0)`

        When theta is not given, 8 annotators are assumed.

        Arguments
        ---------
        nclasses : int
            number of possible annotation classes

        theta : ndarray, shape = (n_annotators, )
            theta[j] is the probability of annotator j being correct

        omega : ndarray, shape = (n_classes, )
            omega[k] is the probability of observing a label of class k

        Returns
        -------
        model : ModelA
            A new model with the given or randomly drawn parameters.
        """

        if theta is None:
            # the model currently assumes 8 annotators (see class docstring)
            nannotators = 8
            theta = ModelA._random_theta(nannotators)

        if omega is None:
            omega = ModelA._random_omega(nclasses)

        return ModelA(nclasses, theta, omega)

    @staticmethod
    def _random_theta(nannotators):
        return np.random.uniform(low=0.6, high=0.95, size=(nannotators, ))

    @staticmethod
    def _random_omega(nclasses):
        beta = 2. * np.ones((nclasses, ))
        return np.random.dirichlet(beta)

    def generate_annotations(self, nitems):
        """Generate random annotations from the model.

        The method samples random annotations from the probability
        distribution defined by the model parameters:

            1) generate correct/incorrect labels for the three annotators,
               according to the parameters `theta`

            2) generate agreement patterns (which annotator agrees which whom)
               given the correctness information and the parameters `alpha`

            3) generate the annotations given the agreement patterns and the
               parameters `omega`

        Items are split into `nannotators` consecutive blocks of at most
        ``ceil(nitems / nannotators)`` items. Block ``j`` is annotated by
        the `nannotators_per_item` annotators ``j, j+1, ...`` (modulo
        `nannotators`), following the loop design.

        Note that, according to the model's definition, only three annotators
        per item return an annotation. Non-observed annotations have the
        standard value of :attr:`~pyanno.util.MISSING_VALUE`.

        Arguments
        ---------
        nitems : int
            number of annotations to draw from the model

        Returns
        -------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i
        """
        nannotators = self.nannotators
        # The block size must be an int: numpy rejects float slice bounds
        # and np.random.rand rejects float dimensions.
        nitems_per_loop = int(np.ceil(float(nitems) / nannotators))

        annotations = np.empty((nitems, nannotators), dtype=int)
        annotations.fill(MISSING_VALUE)

        # loop over annotator triplets (loop design)
        for j in range(nannotators):
            triplet_indices = (np.arange(j, j + self.nannotators_per_item)
                               % nannotators)
            # Clip both ends: with few items, trailing blocks are empty
            # instead of having a negative size.
            start_idx = min(nitems, j * nitems_per_loop)
            stop_idx = min(nitems, (j + 1) * nitems_per_loop)
            nitems_this_loop = stop_idx - start_idx
            if nitems_this_loop == 0:
                continue

            # -- step 1: generate correct / incorrect labels

            # parameters for this triplet
            theta_triplet = self.theta[triplet_indices]
            incorrect = self._generate_incorrectness(nitems_this_loop,
                                                     theta_triplet)

            # -- step 2: generate agreement patterns given correctness
            # convert boolean correctness into combination indices
            # (indices as in Table 3)
            agreement = self._generate_agreement(incorrect)

            # -- step 3: generate annotations
            annotations[start_idx:stop_idx, triplet_indices] = (
                self._generate_annotations(agreement))

        return annotations

    def _generate_incorrectness(self, n, theta_triplet):
        _rnd = np.random.rand(n, self.nannotators_per_item)
        incorrect = _rnd >= theta_triplet
        return incorrect

    def _generate_agreement(self, incorrect):
        """Return indices of agreement pattern given correctness pattern.

        Parameters
        ----------
        incorrect : ndarray of bool, shape = (n_items, 3)
            incorrect[i, k] is True if annotator k of the triplet is wrong
            on item i.

        Returns
        -------
        agreement : ndarray of int, shape = (n_items, )
            The indices returned correspond to agreement patterns
            as in Table 3: 0=aaa, 1=aaA, 2=aAa, 3=Aaa, 4=Aa@
        """

        # create tensor A_ijk
        # (cf. Table 3 in Rzhetsky et al., 2009, suppl. mat.)
        # Each row is a distribution over the 5 agreement patterns.
        alpha = self._compute_alpha()
        agreement_tbl = np.array(
            [[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.],
             [0., 0., 0., 1., 0.], [0., 0., 0., alpha[0], 1. - alpha[0]],
             [0., 0., alpha[1], 0., 1. - alpha[1]],
             [0., alpha[2], 0., 0., 1. - alpha[2]],
             [alpha[3], alpha[4], alpha[5], alpha[6], 1. - alpha[3:].sum()]])

        # this array maps boolean correctness patterns (e.g., CCI) to
        # indices in the agreement tensor, `agreement_tbl`
        correctness_to_agreement_idx = np.array([0, 3, 2, 6, 1, 5, 4, 7])

        # convert correctness pattern to index in the A_ijk tensor
        # (bit k of the code is set when annotator k is incorrect)
        correct_idx = correctness_to_agreement_idx[incorrect[:, 0] * 1 +
                                                   incorrect[:, 1] * 2 +
                                                   incorrect[:, 2] * 4]

        nitems_per_loop = incorrect.shape[0]
        agreement = np.empty((nitems_per_loop, ), dtype=int)
        for i, row in enumerate(correct_idx):
            # generate agreement pattern according to A_ijk; take the single
            # drawn element explicitly (as _generate_annotations does)
            # instead of relying on NumPy's deprecated size-1-array to
            # scalar conversion
            agreement[i] = random_categorical(agreement_tbl[row], 1)[0]

        return agreement

    def _generate_annotations(self, agreement):
        """Generate triplet annotations given agreement pattern.

        Parameters
        ----------
        agreement : ndarray of int, shape = (n_items, )
            Agreement pattern of each item (0=aaa, 1=aaA, 2=aAa, 3=Aaa,
            4=Aa@).

        Returns
        -------
        annotations : ndarray of int, shape = (n_items, 3)
            One annotation triplet per item, drawn among the triplets
            compatible with the item's pattern with probability
            proportional to the product of their `omega` values.
        """
        nitems_per_loop = agreement.shape[0]
        omega = self.omega
        # The compatible triplets and their distribution depend only on
        # nclasses, omega and the pattern: compute them once per pattern
        # (as _log_likelihood_counts does) instead of once per item.
        pattern_to_indices = _compatibility_tables(self.nclasses)
        distributions = []
        for pattern in range(5):
            distr = omega[pattern_to_indices[pattern]].prod(1)
            distributions.append(distr / distr.sum())

        annotations = np.empty((nitems_per_loop, 3), dtype=int)
        for i, pattern in enumerate(agreement):
            compatible = pattern_to_indices[pattern]
            # draw annotation
            compatible_idx = random_categorical(distributions[pattern], 1)[0]
            annotations[i, :] = compatible[compatible_idx, :]
        return annotations

    ##### Parameters estimation methods #######################################

    def mle(self, annotations, estimate_omega=True):
        """Computes maximum likelihood estimate (MLE) of parameters.

        Estimate the parameters :attr:`theta` and :attr:`omega` from a set of
        observed annotations using maximum likelihood estimation.

        Arguments
        ---------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i

        estimate_omega : bool
            If True, the parameters :attr:`omega` are estimated by the empirical
            class frequency. If False, :attr:`omega` is left unchanged.
        """
        self._raise_if_incompatible(annotations)

        def negative_log_likelihood(params, counts):
            # The optimizer minimizes, so hand it the negated likelihood.
            self.theta = params
            return -self._log_likelihood_counts(counts)

        self._parameter_estimation(negative_log_likelihood, annotations,
                                   estimate_omega=estimate_omega)

    def map(self, annotations, estimate_omega=True):
        """Computes maximum a posteriori (MAP) estimate of parameters.

        Estimate the parameters :attr:`theta` and :attr:`omega` from a set of
        observed annotations using maximum a posteriori estimation.

        Arguments
        ---------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i

        estimate_omega : bool
            If True, the parameters :attr:`omega` are estimated by the empirical
            class frequency. If False, :attr:`omega` is left unchanged.
        """
        self._raise_if_incompatible(annotations)

        def negative_log_posterior(params, counts):
            # Unnormalized log posterior = log likelihood + log prior,
            # negated because the optimizer minimizes.
            self.theta = params
            log_lhood = self._log_likelihood_counts(counts)
            return -(log_lhood + self._log_prior())

        self._parameter_estimation(negative_log_posterior, annotations,
                                   estimate_omega=estimate_omega)

    def _parameter_estimation(self,
                              objective,
                              annotations,
                              estimate_omega=True):
        """Fit :attr:`theta` by minimizing `objective` with ``fmin``.

        Parameters
        ----------
        objective : callable
            ``objective(params, counts)`` returns the value to minimize.
        annotations : ndarray, shape = (n_items, n_annotators)
            Observed annotations; converted to counts once up front.
        estimate_omega : bool
            Forwarded to :meth:`_random_initial_parameters`; decides whether
            :attr:`omega` is re-estimated from the annotations.

        Side effects: sets :attr:`omega` before optimizing and :attr:`theta`
        to the optimum found.
        """
        counts = compute_counts(annotations, self.nclasses)

        theta_start, self.omega = self._random_initial_parameters(
            annotations, estimate_omega)

        logger.info('Start parameters optimization...')

        best_theta = scipy.optimize.fmin(objective, theta_start,
                                         args=(counts, ),
                                         xtol=1e-4, ftol=1e-4,
                                         disp=False, maxiter=10000)

        logger.info('Parameters optimization finished')

        self.theta = best_theta

    def _random_initial_parameters(self, annotations, estimate_omega):
        """Return starting values ``(theta, omega)`` for the optimizer.

        omega is the empirical label frequency of `annotations` when
        `estimate_omega` is True, else the current :attr:`omega`; theta is
        drawn at random by :meth:`_random_theta`.
        """
        # TODO duplication w/ ModelBtLoopDesign
        omega = (labels_frequency(annotations, self.nclasses)
                 if estimate_omega else self.omega)
        return ModelA._random_theta(self.nannotators), omega

    ##### Model likelihood methods ############################################

    def log_likelihood(self, annotations):
        r"""Compute the log likelihood of a set of annotations given the model.

        Returns :math:`\log P(\mathbf{x} | \omega, \theta)`,
        where :math:`\mathbf{x}` is the array of annotations.

        Arguments
        ---------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i

        Returns
        -------
        log_lhood : float
            log likelihood of `annotations`
        """
        # presumably raises if `annotations` does not fit this model
        self._raise_if_incompatible(annotations)

        counts = compute_counts(annotations, self.nclasses)
        return self._log_likelihood_counts(counts)

    # TODO code duplication with ModelBtLoopDesign -> refactor
    def _log_likelihood_counts(self, counts):
        """Compute the log likelihood of annotations given the model.

        This method assumes the data is in counts format: column ``i`` of
        `counts` is scored against the theta parameters of annotators
        ``i, i+1, i+2`` (modulo `nannotators`, in sorted order).

        Returns
        -------
        llhood : float
            The log likelihood, or SMALLEST_FLOAT when some theta lies
            outside the interval (0, 1].
        """
        # TODO: check if it's possible to replace these constraints with
        # bounded optimization
        # Out-of-bounds parameters get a tiny finite score instead of -inf
        # (presumably to keep the optimizer working with finite values).
        if np.amin(self.theta) <= 0 or np.amax(self.theta) > 1:
            return SMALLEST_FLOAT

        # compute alpha and beta (they do not depend on theta)
        alpha = self._compute_alpha()
        pattern_to_indices = _compatibility_tables(self.nclasses)
        # beta[pattern]: normalized omega-probability of each annotation
        # triplet compatible with agreement pattern `pattern`
        beta = []
        for pattern in range(5):
            prob = self.omega[pattern_to_indices[pattern]].prod(1)
            beta.append(prob / prob.sum())

        nannotators = self.nannotators
        llhood = 0.
        # one annotator triplet per annotator in the loop design (8 by
        # default), matching the modulo used to build the triplet indices
        for i in range(nannotators):
            # extract the theta parameters for this triplet
            triplet_indices = np.arange(i, i + 3) % nannotators
            triplet_indices.sort()
            theta_triplet = self.theta[triplet_indices]

            # compute the likelihood for the triplet
            llhood += self._log_likelihood_triplet(counts[:, i], theta_triplet,
                                                   alpha, beta)

        return llhood

    def _log_likelihood_triplet(self, counts_triplet, theta_triplet, alpha,
                                beta):
        """Compute the log likelihood of data for one triplet of annotators.

        Parameters
        ----------
        counts_triplet : ndarray
            count data for one combination of annotators
        theta_triplet : ndarray, shape = (3, )
            theta parameters of the current triplet
        alpha : ndarray
            agreement parameters, as returned by ``_compute_alpha``
        beta : list of ndarray
            beta[pattern] holds the probability of each annotation triplet
            compatible with agreement pattern `pattern`

        Returns
        -------
        llhood : float
        """
        nclasses = self.nclasses
        pattern_to_indices = _compatibility_tables(nclasses)

        llhood = 0.
        # agreement patterns: 0=aaa, 1=aaA, 2=aAa, 3=Aaa, 4=Aa@
        for pattern in range(5):
            # P(V | A, T, model) = P(V | A) * P(A | T) * P(T), where
            # _prob_a_and_t supplies P(A | T) * P(T) ("alpha * theta")
            joint = (self._prob_a_and_t(pattern, theta_triplet, alpha) *
                     beta[pattern])

            # P(V | model) sums the conditional probability over A and T;
            # weight each log-probability by its observed count.
            count_indices = _triplet_to_counts_index(
                pattern_to_indices[pattern], nclasses)
            log_joint = ninf_to_num(np.log(joint))
            llhood += (counts_triplet[count_indices] * log_joint).sum()

        return llhood

    def _log_prior(self):
        """Compute log probability of prior on the theta parameters."""
        log_prob = scipy.stats.beta._logpdf(self.theta, 2., 1.).sum()
        if np.isneginf(log_prob):
            log_prob = SMALLEST_FLOAT
        return log_prob

    def _prob_a_and_t(self, pattern, theta_triplet, alpha):
        # TODO make more robust by taking logarithms earlier
        # TODO could be vectorized some more using the A_ijk tensor
        # 0=aaa, 1=aaA, 2=aAa, 3=Aaa, 4=Aa@

        # abbreviations
        thetat = theta_triplet
        not_thetat = (1. - theta_triplet)

        if pattern == 0:  # aaa patterns
            prob = (thetat.prod() + not_thetat.prod() * alpha[3])

        elif pattern == 1:  # aaA patterns
            prob = (thetat[0] * thetat[1] * not_thetat[2] +
                    not_thetat[0] * not_thetat[1] * thetat[2] * alpha[2] +
                    not_thetat[0] * not_thetat[1] * not_thetat[2] * alpha[4])

        elif pattern == 2:  # aAa patterns
            prob = (thetat[0] * not_thetat[1] * thetat[2] +
                    not_thetat[0] * thetat[1] * not_thetat[2] * alpha[1] +
                    not_thetat[0] * not_thetat[1] * not_thetat[2] * alpha[5])

        elif pattern == 3:  # Aaa patterns
            prob = (not_thetat[0] * thetat[1] * thetat[2] +
                    thetat[0] * not_thetat[1] * not_thetat[2] * alpha[0] +
                    not_thetat[0] * not_thetat[1] * not_thetat[2] * alpha[6])

        elif pattern == 4:  # Aa@ pattern
            prob = (not_thetat[0] * not_thetat[1] * not_thetat[2] *
                    (1. - alpha[3] - alpha[4] - alpha[5] - alpha[6]) +
                    thetat[0] * not_thetat[1] * not_thetat[2] *
                    (1. - alpha[0]) +
                    not_thetat[0] * thetat[1] * not_thetat[2] *
                    (1. - alpha[1]) +
                    not_thetat[0] * not_thetat[1] * thetat[2] *
                    (1. - alpha[2]))

        return prob

    ##### Sampling posterior over parameters ##################################

    def sample_posterior_over_accuracy(self,
                                       annotations,
                                       nsamples,
                                       burn_in_samples=100,
                                       thin_samples=5,
                                       target_rejection_rate=0.3,
                                       rejection_rate_tolerance=0.2,
                                       step_optimization_nsamples=500,
                                       adjust_step_every=100):
        """Return samples from posterior distribution over theta given data.

        Samples are drawn with a variant of the Metropolis-Hastings Markov
        Chain Monte Carlo (MCMC) algorithm, in two phases:

            1) *step size estimation phase*: the MCMC step size is adjusted
               until a target rejection rate is reached.

            2) *sampling phase*: samples are collected using the step size
               found in phase 1.

        The model's `theta` is overwritten while sampling and restored
        afterwards, even if an exception is raised.

        Arguments
        ---------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i

        nsamples : int
            number of samples to draw from the posterior

        burn_in_samples : int
            Discard the first `burn_in_samples` during the initial burn-in
            phase, where the Monte Carlo chain converges to the posterior

        thin_samples : int
            Only return one every `thin_samples` samples in order to reduce
            the auto-correlation in the sampling chain. This is called
            "thinning" in MCMC parlance.

        target_rejection_rate : float
            target rejection rate for the step size estimation phase

        rejection_rate_tolerance : float
            the step size estimation phase ends when the rejection rate of
            every parameter is within `rejection_rate_tolerance` of
            `target_rejection_rate`

        step_optimization_nsamples : int
            number of samples to draw in the step size estimation phase

        adjust_step_every : int
            number of samples after which the step size is adjusted during
            the step size estimation phase

        Returns
        -------
        samples : ndarray, shape = (n_samples, n_annotators)
            samples[i,:] is one sample from the posterior distribution over the
            parameters `theta`
        """

        self._raise_if_incompatible(annotations)
        total_nsamples = self._compute_total_nsamples(nsamples,
                                                      burn_in_samples,
                                                      thin_samples)

        counts = compute_counts(annotations, self.nclasses)

        # Bind the methods once; the target function below is evaluated at
        # every step of the chain.
        log_likelihood = self._log_likelihood_counts
        log_prior = self._log_prior

        def log_target(params, counts):
            # The model methods read `self.theta`, so install the proposed
            # parameters before evaluating the unnormalized log posterior.
            self.theta = params
            return log_likelihood(counts) + log_prior()

        # TODO the save/restore of `theta` is rather ugly; sample over a copy
        #      of the model instead
        original_theta = self.theta
        try:
            start = self.theta.copy()
            lower = np.zeros((self.nannotators, ))
            upper = np.ones((self.nannotators, ))

            # phase 1: tune the step size to the target rejection rate
            step = optimize_step_size(log_target, start, counts,
                                      lower, upper,
                                      step_optimization_nsamples,
                                      adjust_step_every, target_rejection_rate,
                                      rejection_rate_tolerance)

            # phase 2: draw the chain with the tuned step size
            raw_samples = sample_distribution(log_target, start, counts,
                                              step, total_nsamples, lower,
                                              upper)

            return self._post_process_samples(raw_samples, burn_in_samples,
                                              thin_samples)
        finally:
            self.theta = original_theta

    ##### Posterior distributions #############################################

    # TODO ideally, one would infer the posterior over correctness (T_ijk)
    #   first, and then return the probability of each value
    #    def infer_correctness(self, annotations):
    #        """Infer posterior distribution over correctness patterns."""
    #        nitems = annotations.shape[0]
    #        nclasses = self.nclasses
    #
    #        posterior = np.zeros((nitems, self.annotators_per_item**2))
    #        alpha = self._compute_alpha()
    #        for i, row in enumerate(annotations):
    #            valid_idx = np.where(row >= 0)
    #            vijk = row[valid_idx]
    #            tijk = self.theta[valid_idx]
    #            p = self._compute_posterior_T_triplet(vijk, tijk, alpha)
    #            posteriors[i, :] = p
    #
    #        return posteriors
    #
    #
    #    def _compute_posterior_T_triplet(self, v, t, alpha):
    #        # switch over agreement pattern
    #        # 0=aaa, 1=aaA, 2=aAa, 3=Aaa, 4=Aa@
    #        if v[0] == v[1] == v[2]:  # aaa pattern
    #            pass

    def infer_labels(self, annotations):
        r"""Infer posterior distribution over label classes.

        Compute the posterior distribution over label classes given observed
        annotations, :math:`P( \mathbf{y} | \mathbf{x}, \theta, \omega)`.

        Arguments
        ---------
        annotations : ndarray, shape = (n_items, n_annotators)
            annotations[i,j] is the annotation of annotator j for item i;
            negative entries are treated as missing

        Returns
        -------
        posterior : ndarray, shape = (n_items, n_classes)
            posterior[i,k] is the posterior probability of class k given the
            annotation observed in item i.
        """

        self._raise_if_incompatible(annotations)

        nitems = annotations.shape[0]
        posteriors = np.zeros((nitems, self.nclasses))

        # alpha only depends on the model parameters, not on the item
        alpha = self._compute_alpha()
        for i, row in enumerate(annotations):
            # keep the observed annotations together with the accuracy of
            # the annotators who produced them (fancy indexing copies)
            valid = np.where(row >= 0)
            vijk = row[valid]
            tijk = self.theta[valid]
            posteriors[i, :] = self._compute_posterior_triplet(vijk, tijk,
                                                               alpha)

        return posteriors

    def _compute_posterior_triplet(self, vijk, tijk, alpha):
        nclasses = self.nclasses
        posteriors = np.zeros(nclasses, float)

        #-----------------------------------------------
        # aaa
        if vijk[0] == vijk[1] and vijk[1] == vijk[2]:
            x1 = tijk[0] * tijk[1] * tijk[2]
            x2 = (1 - tijk[0]) * (1 - tijk[1]) * (1 - tijk[2])
            p1 = x1 / (x1 + alpha[3] * x2)
            p2 = (1 - p1) / (nclasses - 1)

            for j in range(nclasses):
                if vijk[0] == j:
                    posteriors[j] = p1
                else:
                    posteriors[j] = p2

        #-----------------------------------------------
        # aaA
        elif vijk[0] == vijk[1] and vijk[1] != vijk[2]:
            x1 = tijk[0] * tijk[1] * (1 - tijk[2])
            x2 = (1 - tijk[0]) * (1 - tijk[1]) * tijk[2]
            x3 = (1 - tijk[0]) * (1 - tijk[1]) * (1 - tijk[2])

            # a is correct
            p1 = x1 / (x1 + alpha[2] * x2 + alpha[4] * x3)

            # A is correct
            p2 = (alpha[2] * x2) / (x1 + alpha[2] * x2 + alpha[4] * x3)

            # neither
            p3 = (1 - p1 - p2) / (nclasses - 2)

            for j in range(nclasses):
                if vijk[0] == j:
                    posteriors[j] = p1
                elif vijk[2] == j:
                    posteriors[j] = p2
                else:
                    posteriors[j] = p3

        #-----------------------------------------------
        # aAa
        elif vijk[0] == vijk[2] and vijk[1] != vijk[2]:
            x1 = tijk[0] * (1 - tijk[1]) * tijk[2]
            x2 = (1 - tijk[0]) * tijk[1] * (1 - tijk[2])
            x3 = (1 - tijk[0]) * (1 - tijk[1]) * (1 - tijk[2])

            # a is correct
            p1 = x1 / (x1 + alpha[1] * x2 + alpha[5] * x3)

            # A is correct
            p2 = (alpha[1] * x2) / (x1 + alpha[1] * x2 + alpha[5] * x3)

            # neither
            p3 = (1 - p1 - p2) / (nclasses - 2)

            for j in range(nclasses):
                if vijk[0] == j:
                    posteriors[j] = p1
                elif vijk[1] == j:
                    posteriors[j] = p2
                else:
                    posteriors[j] = p3

        #-----------------------------------------------
        # Aaa
        elif vijk[1] == vijk[2] and vijk[0] != vijk[2]:
            x1 = (1 - tijk[0]) * tijk[1] * tijk[2]
            x2 = tijk[0] * (1 - tijk[1]) * (1 - tijk[2])
            x3 = (1 - tijk[0]) * (1 - tijk[1]) * (1 - tijk[2])

            # a is correct
            p1 = x1 / (x1 + alpha[0] * x2 + alpha[6] * x3)

            # A is correct
            p2 = (alpha[0] * x2) / (x1 + alpha[0] * x2 + alpha[6] * x3)

            # neither
            p3 = (1 - p1 - p2) / (nclasses - 2)

            for j in range(nclasses):
                if vijk[0] == j:
                    posteriors[j] = p2
                elif vijk[2] == j:
                    posteriors[j] = p1
                else:
                    posteriors[j] = p3

        #-----------------------------------------------
        # aAb
        elif vijk[0] != vijk[1] and vijk[1] != vijk[2]:
            x1 = tijk[0] * (1 - tijk[1]) * (1 - tijk[2])
            x2 = (1 - tijk[0]) * tijk[1] * (1 - tijk[2])
            x3 = (1 - tijk[0]) * (1 - tijk[1]) * tijk[2]
            x4 = (1 - tijk[0]) * (1 - tijk[1]) * (1 - tijk[2])

            summa1 = 1 - alpha[3] - alpha[4] - alpha[5] - alpha[6]
            summa2 = ((1 - alpha[0]) * x1 + (1 - alpha[1]) * x2 +
                      (1 - alpha[2]) * x3 + summa1 * x4)

            # a is correct
            p1 = (1 - alpha[0]) * x1 / summa2

            # A is correct
            p2 = (1 - alpha[1]) * x2 / summa2

            # b is correct
            p3 = (1 - alpha[2]) * x3 / summa2

            # (a, A, b) are all incorrect
            p4 = (summa1 * x4 / summa2) / (nclasses - 3)

            for j in range(nclasses):
                if vijk[0] == j:
                    posteriors[j] = p1
                elif vijk[1] == j:
                    posteriors[j] = p2
                elif vijk[2] == j:
                    posteriors[j] = p3
                else:
                    posteriors[j] = p4

        # check posteriors: non-negative, sum to 1
        assert np.abs(posteriors.sum() - 1.) < 1e-6
        assert posteriors.min() >= 0.

        return posteriors

    def _compute_alpha(self):
        """Compute the parameters `alpha` given the parameters `omega`.

        Cf. Table 4 in Rzhetsky et al., 2009.
        """

        omega = self.omega
        nclasses = self.nclasses
        alpha = np.zeros((7, ))

        # ------ alpha_1,2,3

        # sum over all doublets
        outer_omega = np.outer(omega, omega)
        sum_wi_wk = outer_omega.sum()

        # sum over all omega_i * omega_j, where i!=k and j!=k
        sum_wi_wj_not_k = np.zeros((nclasses, ))
        # sum over all omega_i ** 2, where i!=k
        sum_wi2_not_k = np.zeros((nclasses, ))

        for k in range(nclasses):
            sum_wi_wj_not_k[k] = (sum_wi_wk - 2 * outer_omega[:, k].sum() +
                                  outer_omega[k, k])
            sum_wi2_not_k[k] = (outer_omega.diagonal().sum() -
                                outer_omega[k, k])

        a1 = (omega * sum_wi2_not_k / sum_wi_wj_not_k).sum()
        alpha[0:3] = a1

        # ------ alpha_4,5,6,7

        # sum over all triplets
        outer_omega3 = (outer_omega[:, :, np.newaxis] *
                        omega[np.newaxis, np.newaxis, :])
        sum_wi_wj_wl = outer_omega3.sum()

        # sum over omega_i * omega_j * omega_l, where i!=k and j!=k and l!=k
        sum_wi_wj_wl_not_k = np.zeros((nclasses, ))
        for k in range(nclasses):
            sum_wi_wj_wl_not_k[k] = (sum_wi_wj_wl -
                                     3. * outer_omega3[:, :, k].sum() +
                                     3. * outer_omega3[:, k, k].sum() -
                                     outer_omega3[k, k, k])
        omega3 = omega**3
        sum_wi3_not_k = omega3.sum() - omega3

        a4 = (omega * sum_wi3_not_k / sum_wi_wj_wl_not_k).sum()
        alpha[3] = a4

        a5 = 0
        for i in range(nclasses):
            tmp = 0
            for j in range(nclasses):
                for k in range(nclasses):
                    if j != i and k != i and j != k:
                        tmp += omega[k] * omega[j]**2
            a5 += omega[i] * tmp / sum_wi_wj_wl_not_k[i]

        alpha[4:7] = a5

        return alpha

    ##### Verify input ########################################################

    def are_annotations_compatible(self, annotations):
        """Check if the annotations are compatible with the models' parameters.

        In addition to the checks of the base class, every item (row) must
        carry exactly `nannotators_per_item` entries that are not
        MISSING_VALUE.

        Returns
        -------
        compatible : bool
        """

        if not super(ModelA, self).are_annotations_compatible(annotations):
            return False

        masked_annotations = np.ma.masked_equal(annotations, MISSING_VALUE)

        # number of valid (non-missing) annotations per row.
        # `getmaskarray` always returns a full boolean array; `.mask` may be
        # the scalar `nomask` when no entry is masked, and `.sum(1)` on that
        # scalar would fail.
        nvalid = (~np.ma.getmaskarray(masked_annotations)).sum(1)
        return bool(np.all(nvalid == self.nannotators_per_item))