示例#1
0
class ViewerCreator(TemplateMixin):
    """Widget that lets the user create a new viewer of any registered type."""

    template = load_template("viewer_creator.vue", __file__).tag(sync=True)
    # Serializable summary of the viewer registry: one {'name', 'label'} dict
    # per registered viewer.
    viewer_types = List([]).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Traitlets cannot serialize the viewer classes themselves, so only
        # each registry entry's name and label are exposed to the front end.
        registered = viewer_registry.members
        self.viewer_types = [
            dict(name=viewer_name, label=entry['label'])
            for viewer_name, entry in registered.items()
        ]

    def vue_create_viewer(self, name):
        """
        Broadcast a request for a new, empty viewer.

        Parameters
        ----------
        name : str
            Key of the desired viewer in ``viewer_registry.members``.
        """
        message = NewViewerMessage(viewer_registry.members[name]['cls'],
                                   data=None,
                                   sender=self)
        self.hub.broadcast(message)
示例#2
0
class UnifiedSlider(TemplateMixin):
    """
    Single slider driving the slice index of every watched image viewer.

    A viewer becomes watched once 3D data is added to it (and it is a
    ``BqplotImageView``). While ``linked`` is True, moving the slider updates
    every watched viewer, and a slice change in any watched viewer moves the
    slider.
    """
    template = load_template("unified_slider.vue", __file__).tag(sync=True)
    slider = Any(0).tag(sync=True)
    min_value = Float(0).tag(sync=True)
    max_value = Float(100).tag(sync=True)
    linked = Bool(True).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._watched_viewers = []

        # Listen for add data events. **Note** this should only be used in
        #  cases where there is a specific type of data expected and arbitrary
        #  viewers are not expected to be created. That is, the expected data
        #  in _all_ viewers should be uniform.
        self.session.hub.subscribe(self,
                                   AddDataMessage,
                                   handler=self._on_data_added)

    @observe("linked")
    def _on_linked_changed(self, event):
        """Attach (linked) or detach (unlinked) the slice callback on every
        watched viewer."""
        for viewer in self._watched_viewers:
            if event['new']:
                viewer.state.add_callback('slices', self._slider_value_updated)
            else:
                viewer.state.remove_callback('slices',
                                             self._slider_value_updated)

    def _on_data_added(self, msg):
        """
        Start watching the viewer that just received 3D data.

        ``max_value`` is set from the first data axis (assumes slicing runs
        along axis 0 -- TODO confirm for all cube orientations).

        Parameters
        ----------
        msg : AddDataMessage
            Message carrying the added ``data`` and the target ``viewer``.
        """
        if not (len(msg.data.shape) == 3 and
                isinstance(msg.viewer, BqplotImageView)):
            return

        self.max_value = msg.data.shape[0] - 1

        if msg.viewer in self._watched_viewers:
            return
        self._watched_viewers.append(msg.viewer)

        # Only follow the viewer's slice changes while linked. When unlinked,
        # _on_linked_changed attaches the callback once linking resumes, so
        # registering it here would link a viewer the user has unlinked.
        if self.linked:
            msg.viewer.state.add_callback('slices',
                                          self._slider_value_updated)

    def _slider_value_updated(self, value):
        """Mirror a watched viewer's first slice index into ``slider``."""
        if len(value) > 0:
            self.slider = float(value[0])

    @observe('slider')
    def _on_slider_updated(self, event):
        """Push the slider position to every watched viewer while linked."""
        if not event['new']:
            value = 0
        else:
            value = int(event['new'])

        if self.linked:
            for viewer in self._watched_viewers:
                viewer.state.slices = (value, 0, 0)
示例#3
0
class SubsetTools(TemplateMixin):
    """Container widget bundling the subset selector and selection-mode menu."""

    template = load_template("subset_tools.vue", __file__).tag(sync=True)
    select = List([]).tag(sync=True)
    subset_mode = Int(0).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        session = self.session
        # Child widgets; presumably rendered under these tag names by the Vue
        # template -- verify against subset_tools.vue.
        self.components = dict([
            ('g-subset-select', SubsetSelect(session=session)),
            ('g-subset-mode', SelectionModeMenu(session=session)),
        ])
示例#4
0
class ImageViewerCreator(TemplateMixin):
    """Widget whose single action creates a new, empty ``ImvizImageView``."""

    template = load_template("image_viewer_creator.vue",
                             __file__).tag(sync=True)
    viewer_types = List([]).tag(sync=True)

    def vue_create_image_viewer(self, *args, **kwargs):
        """Broadcast a request for a new image viewer; arguments are ignored."""
        self.hub.broadcast(
            NewViewerMessage(ImvizImageView, data=None, sender=self))
示例#5
0
class GaussianSmoothingButton(TemplateMixin):
    """
    Dialog that applies ``gaussian_smooth`` to a chosen spectrum and stores
    the smoothed result back into the data collection.
    """
    dialog = Bool(False).tag(sync=True)
    template = load_template("gaussian_smoothing.vue", __file__).tag(sync=True)
    stddev = Unicode().tag(sync=True)
    dc_items = List([]).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Keep the list of selectable labels in step with the data collection.
        for message_type in (DataCollectionAddMessage,
                             DataCollectionDeleteMessage):
            self.hub.subscribe(self, message_type,
                               handler=self._on_data_updated)

        self._selected_data = None

    def _on_data_updated(self, msg):
        """Refresh ``dc_items`` with the label of every data collection entry."""
        self.dc_items = [entry.label for entry in self.data_collection]

    def vue_data_selected(self, event):
        """
        Remember the data collection entry whose label equals ``event``.

        Raises ``StopIteration`` if no entry has that label.
        """
        matches = (entry for entry in self.data_collection
                   if entry.label == event)
        self._selected_data = next(matches)

    def vue_gaussian_smooth(self, *args, **kwargs):
        """
        Smooth the selected spectrum using the dialog's ``stddev`` value, add
        the result as ``"Smoothed <label>"`` and close the dialog.
        """
        stddev = float(self.stddev)
        source = self._selected_data
        spectrum = source.get_object(cls=Spectrum1D)

        smoothed = gaussian_smooth(spectrum, stddev=stddev)

        self.data_collection[f"Smoothed {source.label}"] = smoothed
        self.dialog = False
示例#6
0
class DataTools(TemplateMixin):
    """
    File-import dialog: lets the user pick a path with a ``FileChooser`` and
    broadcasts a ``LoadDataMessage`` for it.
    """
    template = load_template("data_tools.vue", __file__).tag(sync=True)
    dialog = Bool(False).tag(sync=True)
    valid_path = Bool(True).tag(sync=True)
    error_message = Unicode().tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The chooser starts in $JDAVIZ_START_DIR if set, else the CWD.
        start_path = os.environ.get('JDAVIZ_START_DIR', os.path.curdir)

        self._file_upload = FileChooser(start_path)

        self.components = {'g-file-import': self._file_upload}

        self._file_upload.observe(self._on_file_path_changed,
                                  names='file_path')

    def _on_file_path_changed(self, event):
        """
        Validate the chooser's current path and update the error state.

        A ``None`` path (nothing chosen yet) is treated as valid here;
        ``vue_load_data`` reports it when the user tries to load.
        """
        file_path = self._file_upload.file_path
        # ``isfile`` is False for missing paths and for directories. The
        # ``None`` check must guard it because ``isfile(None)`` raises
        # TypeError instead of returning False.
        if file_path is not None and not os.path.isfile(file_path):
            self.error_message = "No file exists at given path"
            self.valid_path = False
        else:
            self.error_message = ""
            self.valid_path = True

    def vue_load_data(self, *args, **kwargs):
        """
        Broadcast a ``LoadDataMessage`` for the chosen path and close the
        dialog on success; otherwise set ``error_message``.
        """
        file_path = self._file_upload.file_path
        if file_path is None:
            self.error_message = "No file selected"
        elif not os.path.exists(file_path):
            self.error_message = "No file exists at given path"
            self.valid_path = False
        else:
            try:
                load_data_message = LoadDataMessage(file_path, sender=self)
                self.hub.broadcast(load_data_message)
            except Exception:
                # Any loader failure is shown in the dialog rather than
                # propagated into the widget framework.
                self.error_message = "An error occurred when loading the file"
            else:
                self.dialog = False
示例#7
0
class LineListTool(TemplateMixin):
    """
    Plugin for loading, displaying and styling spectral line lists in the
    spectrum viewer.

    ``list_contents`` mirrors the viewer's ``spectral_lines`` astropy table
    in a vuetify-friendly form:
    ``{listname: {"lines": [line_dict, ...], "color": str}}``.
    Traitlets only notice reassignment, so every in-place edit of
    ``list_contents`` / ``loaded_lists`` is followed by reassigning the trait
    (first to an empty value, then back) to force a front-end sync.
    """
    dialog = Bool(False).tag(sync=True)
    template = load_template("line_lists.vue", __file__).tag(sync=True)
    dc_items = List([]).tag(sync=True)
    available_lists = List([]).tag(sync=True)
    loaded_lists = List([]).tag(sync=True)
    list_contents = Dict({}).tag(sync=True)
    custom_name = Unicode().tag(sync=True)
    custom_rest = Unicode().tag(sync=True)
    custom_unit = Unicode().tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._viewer = self.app.get_viewer("spectrum-viewer")
        self._viewer_spectrum = None
        self._spectrum1d = None
        self.available_lists = self._viewer.available_linelists()
        # Preset list chosen in the dropdown; None until the user picks one.
        self.list_to_load = None
        self.loaded_lists = ["Custom"]
        self.list_contents = {"Custom": {"lines": [], "color": "#FF0000FF"}}
        # SpectralLine marks keyed by their ``table_index``; this class looks
        # them up by a line's ``name_rest``.
        self.line_mark_dict = {}
        self._units = {}
        self._bounds = {}

        self.hub.subscribe(self,
                           AddDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self,
                           RemoveDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self,
                           SubsetCreateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self,
                           SubsetDeleteMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self,
                           SubsetUpdateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self,
                           AddLineListMessage,
                           handler=self._list_from_notebook)

    def _on_viewer_data_changed(self, msg=None):
        """
        Callback method for when data is added or removed from a viewer, or
        when a subset is created, deleted, or updated. This method receives
        a glue message containing viewer information in the case of the former
        set of events, and updates the units in which to display the lines.

        Notes
        -----
        We do not attempt to parse any data at this point, as it can cause
        visible lag in the application.

        Parameters
        ----------
        msg : `glue.core.Message`
            The glue message passed to this callback method.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Subsets are global and are not linked to specific viewer instances,
        # so it's not required that we match any specific ids for that case.
        # However, if the msg is not none, check to make sure that it's the
        # viewer we care about.
        if msg is not None and msg.viewer_id != self._viewer_id:
            return

        try:
            viewer_data = self.app.get_viewer('spectrum-viewer').data()
        except TypeError:
            warn_message = SnackbarMessage(
                "Line list plugin could not retrieve data from viewer",
                sender=self,
                color="error")
            self.hub.broadcast(warn_message)
            return

        # If no data is currently plotted, don't attempt to update
        if viewer_data is None or len(viewer_data) == 0:
            return

        self._viewer_spectrum = viewer_data[0]

        self._units["x"] = str(self._viewer_spectrum.spectral_axis.unit)
        self._units["y"] = str(self._viewer_spectrum.flux.unit)

        self._bounds["min"] = self._viewer_spectrum.spectral_axis[0]
        self._bounds["max"] = self._viewer_spectrum.spectral_axis[-1]

    def _list_from_notebook(self, msg):
        """
        Callback method for when a spectral line list is added to the specviz
        instance from the notebook.

        ``loaded_lists`` and ``list_contents`` are rebuilt from ``msg.table``
        alone, plus an empty "Custom" list.

        Parameters
        ----------
        msg : `glue.core.Message`
            The glue message passed to this callback method. Includes the line
            data added in ``msg.table`` (an astropy table with Quantity
            ``rest`` values).
        """
        loaded_lists = ["Custom"]
        list_contents = {"Custom": {"lines": [], "color": "#FF0000FF"}}
        # Check the table's column names once; ``"colors" in row`` is not
        # guaranteed to test column names on an astropy Row.
        has_colors = "colors" in msg.table.colnames
        for row in msg.table:
            listname = row["listname"]
            if listname not in loaded_lists:
                loaded_lists.append(listname)
            if listname not in list_contents:
                list_contents[listname] = {
                    "lines": [],
                    "color": "#FF0000FF"
                }
            list_contents[listname]["lines"].append({
                "linename": row["linename"],
                "rest": row["rest"].value,
                "unit": str(row["rest"].unit),
                "colors": row["colors"] if has_colors else "#FF0000FF",
                "show": True,
                "name_rest": row["name_rest"]
            })

        # Trick traitlets into updating
        self.loaded_lists = []
        self.loaded_lists = loaded_lists
        self.list_contents = {}
        self.list_contents = list_contents

        lines_loaded_message = SnackbarMessage(
            "Spectral lines loaded from notebook",
            sender=self,
            color="success")
        self.hub.broadcast(lines_loaded_message)

    def vue_update_available(self):
        """
        Check that the list to select from is up to date
        """
        self.available_lists = get_available_linelists()

    def update_line_mark_dict(self):
        """Rebuild ``line_mark_dict`` from the SpectralLine marks currently
        on the viewer figure."""
        self.line_mark_dict = {
            mark.table_index: mark
            for mark in self._viewer.figure.marks
            if isinstance(mark, SpectralLine)
        }

    def vue_list_selected(self, event):
        """
        Handle list selection from presets dropdown selector
        """
        self.list_to_load = event

    def vue_load_list(self, event):
        """
        Load one of the preset line lists, storing its info in a
        vuetify-friendly manner in addition to loading the astropy table into
        the viewer's spectral_lines attribute.
        """
        if self.list_to_load is None:
            warn_message = SnackbarMessage("No line list selected to load",
                                           sender=self,
                                           color="error")
            self.hub.broadcast(warn_message)
            return
        # Don't need to reload an already loaded list
        if self.list_to_load in self.loaded_lists:
            return
        temp_table = load_preset_linelist(self.list_to_load)

        # Also store basic list contents in a form that vuetify can handle
        # Adds line style parameters that can be changed on the front end
        temp_table["colors"] = "#FF0000FF"

        # Load the table into the main astropy table and get it back, to make
        # sure all values match between the main table and local plugin
        temp_table = self._viewer.load_line_list(temp_table, return_table=True)

        line_list_dict = {"lines": [], "color": "#FF000080"}
        for row in temp_table:
            line_list_dict["lines"].append({
                "linename": row["linename"],
                "rest": row["rest"].value,
                "unit": str(row["rest"].unit),
                "colors": row["colors"],
                "show": True,
                "name_rest": str(row["name_rest"])
            })

        list_contents = self.list_contents
        list_contents[self.list_to_load] = line_list_dict
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = list_contents

        loaded_lists = self.loaded_lists + [self.list_to_load]
        self.loaded_lists = []
        self.loaded_lists = loaded_lists

        self._viewer.plot_spectral_lines()
        self.update_line_mark_dict()

        lines_loaded_message = SnackbarMessage(
            "Spectral lines loaded from preset", sender=self, color="success")
        self.hub.broadcast(lines_loaded_message)

    def vue_add_custom_line(self, event):
        """
        Add a line to the "Custom" line list from UI input.

        Reports an error instead of raising when the rest value is not a
        number or the unit string cannot be parsed.
        """
        try:
            rest = float(self.custom_rest)
            unit = u.Unit(self.custom_unit)
        except (TypeError, ValueError):
            warn_message = SnackbarMessage(
                "Custom line needs a numeric rest value and a valid unit",
                sender=self,
                color="error")
            self.hub.broadcast(warn_message)
            return

        list_contents = self.list_contents
        temp_dict = {
            "linename": self.custom_name,
            "rest": rest,
            "unit": self.custom_unit,
            "colors": list_contents["Custom"]["color"],
            "show": True
        }

        # Add to viewer astropy table
        temp_table = QTable()
        temp_table["linename"] = [temp_dict["linename"]]
        temp_table["rest"] = [rest * unit]
        temp_table["colors"] = [temp_dict["colors"]]
        temp_table = self._viewer.load_line_list(temp_table, return_table=True)

        # Add line to Custom lines in local list
        temp_dict["name_rest"] = str(temp_table[0]["name_rest"])
        list_contents["Custom"]["lines"].append(temp_dict)
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = list_contents

        self._viewer.plot_spectral_line(temp_dict["name_rest"])
        self.update_line_mark_dict()

        lines_loaded_message = SnackbarMessage("Custom spectral line loaded",
                                               sender=self,
                                               color="success")
        self.hub.broadcast(lines_loaded_message)

    def vue_show_all_in_list(self, listname):
        """
        Toggle all lines in list to be visible
        """
        lc = self.list_contents
        for line in lc[listname]["lines"]:
            line["show"] = True
            self._viewer.spectral_lines.loc[line["name_rest"]]["show"] = True
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = lc

        self._viewer.plot_spectral_lines()
        self.update_line_mark_dict()

    def vue_hide_all_in_list(self, listname):
        """
        Toggle all lines in list to be hidden
        """
        lc = self.list_contents
        name_rests = []
        for line in lc[listname]["lines"]:
            line["show"] = False
            name_rests.append(line["name_rest"])
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = lc

        self._viewer.erase_spectral_lines(name_rest=name_rests)
        self.update_line_mark_dict()

    def vue_plot_all_lines(self, event):
        """
        Plot all the currently loaded lines in the viewer
        """
        if self._viewer.spectral_lines is None:
            warn_message = SnackbarMessage("No spectral lines loaded to plot",
                                           sender=self,
                                           color="error")
            self.hub.broadcast(warn_message)
            return
        lc = self.list_contents
        for listname in lc:
            for line in lc[listname]["lines"]:
                line["show"] = True
        self._viewer.spectral_lines["show"] = True
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = lc

        self._viewer.plot_spectral_lines()
        self.update_line_mark_dict()

    def vue_erase_all_lines(self, event):
        """
        Erase all lines from the viewer
        """
        if self._viewer.spectral_lines is None:
            warn_message = SnackbarMessage("No spectral lines to erase",
                                           sender=self,
                                           color="error")
            self.hub.broadcast(warn_message)
            return
        lc = self.list_contents
        for listname in lc:
            for line in lc[listname]["lines"]:
                line["show"] = False
        # Trick traitlets into updating
        self.list_contents = {}
        self.list_contents = lc

        self._viewer.erase_spectral_lines()
        # Drop marks that are no longer on the figure.
        self.update_line_mark_dict()

    def vue_change_visible(self, line):
        """
        Plot or erase a single line as needed when "Visible" checkbox is changed
        """
        name_rest = line["name_rest"]
        if line["show"]:
            self._viewer.plot_spectral_line(name_rest)
        else:
            self._viewer.erase_spectral_lines(name_rest=name_rest)
        self.update_line_mark_dict()

    def vue_set_color(self, data):
        """
        Change the color either of all members of a line list, or of an
        individual line.

        Only the whole-list case (``data`` has "listname") is implemented;
        the single-line case ("linename") is currently a no-op.
        """
        color = data['color']
        if "listname" in data:
            listname = data["listname"]
            all_contents = self.list_contents
            lc = all_contents[listname]
            lc["color"] = color

            for line in lc["lines"]:
                line["colors"] = color
                # Update the astropy table entry
                name_rest = line["name_rest"]
                self._viewer.spectral_lines.loc[name_rest]["colors"] = color
                # Update the color on the plot
                if name_rest in self.line_mark_dict:
                    self.line_mark_dict[name_rest].colors = [color]

            # Trick traitlets into updating
            self.list_contents = {}
            self.list_contents = all_contents

        elif "linename" in data:
            # TODO: per-line recoloring is not implemented yet.
            pass

    def vue_remove_list(self, listname):
        """
        Method to remove line list from available expansion panels when the x
        on the panel header is clicked. Also removes line marks from plot and
        updates the "show" value in the astropy table to False.
        """
        lc = self.list_contents[listname]
        name_rests = []
        for line in lc["lines"]:
            name_rests.append(self.vue_remove_line(line, erase=False))
        self._viewer.erase_spectral_lines(name_rest=name_rests)
        self.update_line_mark_dict()

        self.loaded_lists = [x for x in self.loaded_lists if x != listname]
        # Reassign a new dict rather than ``del`` in place: an in-place
        # mutation of the Dict trait is never synced to the front end.
        self.list_contents = {name: contents
                              for name, contents in self.list_contents.items()
                              if name != listname}

    def vue_remove_line(self, line, erase=True):
        """
        Method to remove a line from the plot when the line is deselected in
        the expansion panel content. Input must have a "name_rest" value for
        indexing on the astropy table.

        Returns the line's ``name_rest`` when ``erase`` is False (so the
        caller can erase marks in bulk), otherwise None.
        """
        name_rest = line["name_rest"]
        # Keep in our spectral line astropy table, but set it to not show on plot
        self._viewer.spectral_lines.loc[name_rest]["show"] = False

        if not erase:
            return name_rest

        # Remove the line from the plot marks
        try:
            self._viewer.erase_spectral_lines(name_rest=name_rest)
        except KeyError as err:
            raise KeyError("line marks: {}".format(
                self._viewer.figure.marks)) from err
        # The mark may already be absent (e.g. the line was hidden before
        # being removed), so its absence is not an error.
        self.line_mark_dict.pop(name_rest, None)
示例#8
0
class LineAnalysis(TemplateMixin):
    """
    Plugin that runs every analysis function in ``FUNCTIONS`` on a spectrum
    chosen from the spectrum viewer and displays the results.
    """
    dialog = Bool(False).tag(sync=True)
    template = load_template("line_analysis.vue", __file__).tag(sync=True)
    dc_items = List([]).tag(sync=True)
    temp_function = Unicode().tag(sync=True)
    available_functions = List(list(FUNCTIONS.keys())).tag(sync=True)
    result_available = Bool(False).tag(sync=True)
    results = List().tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._viewer_spectra = None
        self._spectrum1d = None
        self._units = {}
        self.result_available = False

        self.hub.subscribe(self,
                           AddDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self,
                           RemoveDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self,
                           SubsetCreateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self,
                           SubsetDeleteMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self,
                           SubsetUpdateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

    def _on_viewer_data_changed(self, msg=None):
        """
        Callback method for when data is added or removed from a viewer, or
        when a subset is created, deleted, or updated. This method receives
        a glue message containing viewer information in the case of the former
        set of events, and updates the available data list displayed to the
        user.

        Notes
        -----
        We do not attempt to parse any data at this point, as it can cause
        visible lag in the application.

        Parameters
        ----------
        msg : `glue.core.Message`
            The glue message passed to this callback method.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Subsets are global and are not linked to specific viewer instances,
        # so it's not required that we match any specific ids for that case.
        # However, if the msg is not none, check to make sure that it's the
        # viewer we care about.
        if msg is not None and msg.viewer_id != self._viewer_id:
            return

        viewer = self.app.get_viewer('spectrum-viewer')

        self.dc_items = [
            layer_state.layer.label for layer_state in viewer.state.layers
        ]

    def vue_data_selected(self, event):
        """
        Callback method for when the user has selected data from the drop down
        in the front-end. It is here that we actually parse and create a new
        data object from the selected data. From this data object, unit
        information is scraped, and the selected spectrum is stored and
        analysed.

        Parameters
        ----------
        event : str
            IPyWidget callback event object. In this case, represents the data
            label of the data collection object selected by the user.
        """
        selected_spec = self.app.get_data_from_viewer("spectrum-viewer",
                                                      data_label=event)

        if self._units == {}:
            self._units["x"] = str(selected_spec.spectral_axis.unit)
            self._units["y"] = str(selected_spec.flux.unit)

        for label in self.dc_items:
            if label in self.data_collection:
                self._label_to_link = label
                break

        self._spectrum1d = selected_spec

        self._run_functions()

    def _centroid_region(self):
        """
        Return a ``SpectralRegion`` spanning the unmasked part of the selected
        spectrum's spectral axis (the whole axis when there is no mask).
        """
        spectral_axis = self._spectrum1d.spectral_axis
        mask = self._spectrum1d.mask
        if mask is not None:
            # Keep only the spectral-axis values whose mask entry is falsy.
            spectral_axis = spectral_axis[np.where(np.logical_not(mask))]
        return SpectralRegion(spectral_axis[0], spectral_axis[-1])

    def _run_functions(self, *args, **kwargs):
        """
        Evaluate every function in ``FUNCTIONS`` on the selected spectrum and
        publish the results.

        Each ``results`` entry is ``{'function': name, 'result': str(value)}``.
        The "Centroid" function additionally receives a region argument built
        by ``_centroid_region``.
        """
        temp_results = []
        for function in FUNCTIONS:
            # Centroid function requires a region argument, create one to pass
            if function == "Centroid":
                temp_result = FUNCTIONS[function](self._spectrum1d,
                                                  self._centroid_region())
            else:
                temp_result = FUNCTIONS[function](self._spectrum1d)

            temp_results.append({
                'function': function,
                'result': str(temp_result)
            })

        # Publish once after all functions ran, instead of re-syncing the
        # partial list to the front end on every iteration.
        self.result_available = True

        # Trick traitlets into updating
        self.results = []
        self.results = temp_results
示例#9
0
class MomentMap(TemplateMixin):
    """
    Plugin that computes a moment map of a spectral cube over a spectral
    range, adds it to the data collection, and can save it as FITS.
    """
    template = load_template("moment_maps.vue", __file__).tag(sync=True)
    n_moment = Any().tag(sync=True)
    dc_items = List([]).tag(sync=True)
    selected_data = Unicode().tag(sync=True)

    filename = Unicode().tag(sync=True)

    moment_available = Bool(False).tag(sync=True)
    spectral_min = Any().tag(sync=True)
    spectral_max = Any().tag(sync=True)
    spectral_unit = Unicode().tag(sync=True)
    spectral_subset_items = List(["None"]).tag(sync=True)
    selected_subset = Unicode("None").tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hub.subscribe(self,
                           DataCollectionAddMessage,
                           handler=self._on_data_updated)
        self.hub.subscribe(self,
                           DataCollectionDeleteMessage,
                           handler=self._on_data_updated)
        # SubsetCreateMessage is intentionally not subscribed; see
        # _on_subset_created.
        self._selected_data = None
        self.n_moment = 0
        self.moment = None
        self._filename = None
        self.spectral_min = 0.0
        self.spectral_max = 0.0
        self._spectral_subsets = {}

    def _on_data_updated(self, msg):
        """
        Refresh ``dc_items``; if nothing is selected yet, select the first
        entry that can be returned as a ``SpectralCube`` and default the
        spectral range to its full extent.
        """
        self.dc_items = [x.label for x in self.data_collection]
        if self._selected_data is not None:
            return
        for data in self.data_collection:
            try:
                cube = data.get_object(cls=SpectralCube)
            except (ValueError, TypeError):
                # Skip data that can't be returned as a SpectralCube
                continue
            # Record the selection directly: the ``selected_data`` observer
            # does not fire when the label is unchanged, and a non-cube entry
            # must never be left as the current selection.
            self._selected_data = data
            self.selected_data = data.label
            # Also set the spectral min and max to default to the full range
            self.spectral_min = cube.spectral_axis[0].value
            self.spectral_max = cube.spectral_axis[-1].value
            self.spectral_unit = str(cube.spectral_axis.unit)
            break

    def _on_subset_created(self, msg):
        """Currently unimplemented due to problems with the SubsetCreateMessage"""
        raise ValueError(msg)

    @observe("selected_data")
    def _on_data_selected(self, event):
        """Track the newly selected data and adopt its spectral range if its
        spectral unit differs from the current one."""
        self._selected_data = next(
            (x for x in self.data_collection if x.label == event['new']))
        cube = self._selected_data.get_object(cls=SpectralCube)
        # Update spectral bounds and unit if we've switched to another unit
        if str(cube.spectral_axis.unit) != self.spectral_unit:
            self.spectral_min = cube.spectral_axis[0].value
            self.spectral_max = cube.spectral_axis[-1].value
            self.spectral_unit = str(cube.spectral_axis.unit)

    @observe("selected_subset")
    def _on_subset_selected(self, event):
        """Set the spectral range from the chosen subset, or from the full
        selected cube when "None" is chosen."""
        self._selected_subset = self.selected_subset
        if self._selected_subset == "None":
            # Nothing to reset to until some cube has been selected.
            if self._selected_data is None:
                return
            cube = self._selected_data.get_object(cls=SpectralCube)
            self.spectral_min = cube.spectral_axis[0].value
            self.spectral_max = cube.spectral_axis[-1].value
        else:
            spec_sub = self._spectral_subsets[self._selected_subset]
            unit = u.Unit(self.spectral_unit)
            spec_reg = SpectralRegion.from_center(spec_sub.center.x * unit,
                                                  spec_sub.width * unit)
            self.spectral_min = spec_reg.lower.value
            self.spectral_max = spec_reg.upper.value

    @observe("filename")
    def _on_filename_changed(self, event):
        self._filename = self.filename

    def vue_list_subsets(self, event):
        """Populate the spectral subset selection dropdown"""
        temp_subsets = self.app.get_subsets_from_viewer("spectrum-viewer")
        temp_list = ["None"]
        temp_dict = {}
        # Attempt to filter out spatial subsets
        for key, region in temp_subsets.items():
            if type(region) == RectanglePixelRegion:
                temp_dict[key] = region
                temp_list.append(key)
        self._spectral_subsets = temp_dict
        self.spectral_subset_items = temp_list

    def vue_calculate_moment(self, event):
        """
        Compute moment ``n_moment`` of the selected cube between
        ``spectral_min`` and ``spectral_max`` and add it to the data
        collection.

        Raises
        ------
        ValueError
            If ``n_moment`` is not a non-negative integer.
        """
        # Validate the moment order before doing any slicing work.
        try:
            n_moment = int(self.n_moment)
        except (TypeError, ValueError):
            raise ValueError("Moment must be a non-negative integer") from None
        if n_moment < 0:
            raise ValueError("Moment must be a non-negative integer")

        # Retrieve the data cube and slice out desired region, if specified
        cube = self._selected_data.get_object(cls=SpectralCube)
        spectral_unit = u.Unit(self.spectral_unit)
        spec_min = float(self.spectral_min) * spectral_unit
        spec_max = float(self.spectral_max) * spectral_unit
        slab = cube.spectral_slab(spec_min, spec_max)

        # Calculate the moment and convert to CCDData to add to the viewers
        self.moment = slab.moment(n_moment)

        moment_ccd = CCDData(self.moment.array,
                             wcs=self.moment.wcs,
                             unit=self.moment.unit)

        label = "Moment {}: {}".format(n_moment, self._selected_data.label)
        fname_label = self._selected_data.label.replace("[",
                                                        "_").replace("]", "_")
        self.filename = "moment{}_{}.fits".format(n_moment, fname_label)
        self.data_collection[label] = moment_ccd
        self.moment_available = True

        msg = SnackbarMessage("{} added to data collection".format(label),
                              sender=self,
                              color="success")
        self.hub.broadcast(msg)

    def vue_save_as_fits(self, event):
        """Write the last computed moment map to ``filename`` and tell the
        user the full path it was saved to."""
        self.moment.write(self._filename)
        # Joining onto the CWD leaves absolute filenames unchanged and makes
        # any relative one (with or without directories) absolute.
        full_path = pathlib.Path.cwd() / self._filename
        msg = SnackbarMessage("Moment map saved to {}".format(str(full_path)),
                              sender=self,
                              color="success")
        self.hub.broadcast(msg)
示例#10
0
文件: collapse.py 项目: eteq/jdaviz
class Collapse(TemplateMixin):
    """
    Plugin that collapses the selected data cube along one axis with the
    chosen reduction function and adds the result to the data collection.

    When collapsing over axis 0 the cube is first cut to the spectral range
    ``[spectral_min, spectral_max]`` (in ``spectral_unit``), which defaults to
    the full spectral axis and can be filled in from a spectral subset.
    """
    template = load_template("collapse.vue", __file__).tag(sync=True)
    data_items = List([]).tag(sync=True)
    selected_data_item = Unicode().tag(sync=True)
    axes = List([]).tag(sync=True)
    selected_axis = Int(0).tag(sync=True)
    funcs = List(['Mean', 'Median', 'Min', 'Max', 'Sum']).tag(sync=True)
    selected_func = Unicode('Mean').tag(sync=True)

    spectral_min = Any().tag(sync=True)
    spectral_max = Any().tag(sync=True)
    spectral_unit = Unicode().tag(sync=True)
    spectral_subset_items = List(["None"]).tag(sync=True)
    selected_subset = Unicode("None").tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hub.subscribe(self,
                           DataCollectionAddMessage,
                           handler=self._on_data_updated)
        self.hub.subscribe(self,
                           DataCollectionDeleteMessage,
                           handler=self._on_data_updated)

        self._selected_data = None
        self._label_counter = 0
        # Filled by vue_list_subsets; start empty so the attribute always
        # exists even before the subset dropdown has been populated.
        self._spectral_subsets = {}

    def _on_data_updated(self, msg):
        """
        Refresh the data dropdown when data are added to or removed from the
        data collection, and, if nothing is selected yet, select the first
        dataset that the selection observer accepts.
        """
        self.data_items = [x.label for x in self.data_collection]
        if self._selected_data is not None:
            return

        for label in self.data_items:
            try:
                # Triggers _on_data_item_selected, whose
                # get_object(cls=SpectralCube) call is expected to raise
                # ValueError/TypeError for data that is not a cube.
                self.selected_data_item = label
            except (ValueError, TypeError):
                # The observer assigns _selected_data before it fails; clear
                # it so a later data update tries again.
                self._selected_data = None
                continue
            # Stop at the first accepted dataset. If the label did not change
            # the observer did not run, so _selected_data may still be None
            # and we keep looking.
            if self._selected_data is not None:
                break

    @observe('selected_data_item')
    def _on_data_item_selected(self, event):
        """
        Store the selected data object, reset the spectral range to the full
        spectral axis of its cube, and list its axes for the axis dropdown.
        """
        self._selected_data = next(
            (x for x in self.data_collection if x.label == event['new']))

        # Also set the spectral min and max to default to the full range
        cube = self._selected_data.get_object(cls=SpectralCube)
        self.spectral_min = cube.spectral_axis[0].value
        self.spectral_max = cube.spectral_axis[-1].value
        self.spectral_unit = str(cube.spectral_axis.unit)

        self.axes = list(range(len(self._selected_data.shape)))

    @observe("selected_subset")
    def _on_subset_selected(self, event):
        """
        Update the spectral range from the selected spectral subset, or reset
        it to the full spectral axis of the selected data for "None".
        """
        self._selected_subset = self.selected_subset
        if self._selected_subset == "None":
            cube = self._selected_data.get_object(cls=SpectralCube)
            self.spectral_min = cube.spectral_axis[0].value
            self.spectral_max = cube.spectral_axis[-1].value
        else:
            spec_sub = self._spectral_subsets[self._selected_subset]
            unit = u.Unit(self.spectral_unit)
            spec_reg = SpectralRegion.from_center(spec_sub.center.x * unit,
                                                  spec_sub.width * unit)
            self.spectral_min = spec_reg.lower.value
            self.spectral_max = spec_reg.upper.value

    def vue_list_subsets(self, event):
        """Populate the spectral subset selection dropdown."""
        temp_subsets = self.app.get_subsets_from_viewer("spectrum-viewer")
        temp_list = ["None"]
        temp_dict = {}
        # Attempt to filter out spatial subsets: only rectangular regions are
        # treated as spectral ranges here.
        for key, region in temp_subsets.items():
            if isinstance(region, RectanglePixelRegion):
                temp_dict[key] = region
                temp_list.append(key)
        self._spectral_subsets = temp_dict
        self.spectral_subset_items = temp_list

    def vue_collapse(self, *args, **kwargs):
        """
        Collapse the selected cube along ``selected_axis`` with
        ``selected_func``, add the result to the data collection under a
        numbered label, and link it pixel-wise to the original dataset.
        """
        try:
            spec = self._selected_data.get_object(cls=SpectralCube)
        except AttributeError:
            # e.g. nothing selected yet (_selected_data is None)
            snackbar_message = SnackbarMessage(
                "Unable to perform collapse over selected data.",
                color="error",
                sender=self)
            self.hub.broadcast(snackbar_message)

            return

        # If collapsing over the spectral axis, cut out the desired spectral
        # region. Defaults to the entire spectrum.
        if self.selected_axis == 0:
            spec_min = float(self.spectral_min) * u.Unit(self.spectral_unit)
            spec_max = float(self.spectral_max) * u.Unit(self.spectral_unit)
            spec = spec.spectral_slab(spec_min, spec_max)

        collapsed_spec = getattr(
            spec, self.selected_func.lower())(axis=self.selected_axis)

        data = Data(coords=collapsed_spec.wcs)
        data['flux'] = collapsed_spec.filled_data[...]
        data.get_component('flux').units = str(collapsed_spec.unit)
        data.meta.update(collapsed_spec.meta)

        self._label_counter += 1
        label = f"Collapsed {self._label_counter} {self._selected_data.label}"

        self.data_collection[label] = data

        # Link the new dataset pixel-wise to the original dataset. In general
        # direct pixel to pixel links are the most efficient and should be
        # used in cases like this where we know there is a 1-to-1 mapping of
        # pixel coordinates. Here which axes are linked to which depends on
        # the selected axis. Use the Data object we just added rather than a
        # label lookup, which returns the first dataset with that label.
        (i1, i2), (i1c, i2c) = AXES_MAPPING[self.selected_axis]

        self.data_collection.add_link(
            LinkSame(self._selected_data.pixel_component_ids[i1],
                     data.pixel_component_ids[i1c]))
        self.data_collection.add_link(
            LinkSame(self._selected_data.pixel_component_ids[i2],
                     data.pixel_component_ids[i2c]))

        snackbar_message = SnackbarMessage(
            f"Data set '{self._selected_data.label}' collapsed successfully.",
            color="success",
            sender=self)
        self.hub.broadcast(snackbar_message)
示例#11
0
class GaussianSmooth(TemplateMixin):
    """
    Plugin that smooths the selected spectrum with ``gaussian_smooth`` and
    adds the result to the data collection, linked pixel-wise to the input.
    """
    template = load_template("gaussian_smooth.vue", __file__).tag(sync=True)
    stddev = Any().tag(sync=True)
    dc_items = List([]).tag(sync=True)
    selected_data = Unicode().tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hub.subscribe(self,
                           DataCollectionAddMessage,
                           handler=self._on_data_updated)
        self.hub.subscribe(self,
                           DataCollectionDeleteMessage,
                           handler=self._on_data_updated)

        self._selected_data = None

    def _on_data_updated(self, msg):
        """Refresh the data dropdown when the data collection changes."""
        self.dc_items = [x.label for x in self.data_collection]

    @observe("selected_data")
    def _on_data_selected(self, event):
        """Keep a reference to the data object whose label was selected."""
        self._selected_data = next(
            (x for x in self.data_collection if x.label == event['new']))

    def vue_gaussian_smooth(self, *args, **kwargs):
        """
        Smooth the selected data as a ``Spectrum1D`` using ``stddev`` from
        the dialog, add the result as ``"Smoothed <label>"`` and link it to
        the first pixel axis of the original dataset. Broadcasts an error
        snackbar instead if no usable data is selected.
        """
        # Standard deviation for the Gaussian kernel, taken from the dialog
        size = float(self.stddev)

        try:
            spec = self._selected_data.get_object(cls=Spectrum1D)
        except (AttributeError, TypeError):
            # AttributeError: nothing selected yet (_selected_data is None).
            # TypeError: presumably the selected data could not be
            # translated to a Spectrum1D.
            snackbar_message = SnackbarMessage(
                "Unable to perform smoothing over selected data.",
                color="error",
                sender=self)
            self.hub.broadcast(snackbar_message)

            return

        spec_smoothed = gaussian_smooth(spec, stddev=size)

        label = f"Smoothed {self._selected_data.label}"

        self.data_collection[label] = spec_smoothed

        # Link the new dataset pixel-wise to the original dataset. In general
        # direct pixel to pixel links are the most efficient and should be
        # used in cases like this where we know there is a 1-to-1 mapping of
        # pixel coordinates. Here the smoothing returns a 1-d spectral object
        # which we can link to the first dimension of the original dataset
        # (which could in principle be a cube or a spectrum).
        self.data_collection.add_link(
            LinkSame(self._selected_data.pixel_component_ids[0],
                     self.data_collection[label].pixel_component_ids[0]))

        snackbar_message = SnackbarMessage(
            f"Data set '{self._selected_data.label}' smoothed successfully.",
            color="success",
            sender=self)
        self.hub.broadcast(snackbar_message)
示例#12
0
class GaussianSmooth(TemplateMixin):
    """
    Plugin that applies Gaussian smoothing to the selected data, either along
    the spectral axis (1D) or, in the cubeviz configuration, over the spatial
    slices of a cube (2D kernel). Results are added to the data collection
    under a label that encodes the mode and ``stddev``.
    """
    template = load_template("gaussian_smooth.vue", __file__).tag(sync=True)
    stddev = Any().tag(sync=True)
    dc_items = List([]).tag(sync=True)
    selected_data = Unicode().tag(sync=True)
    show_modes = Bool(False).tag(sync=True)
    smooth_modes = List(["Spectral", "Spatial"]).tag(sync=True)
    selected_mode = Unicode("Spectral").tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hub.subscribe(self,
                           DataCollectionAddMessage,
                           handler=self._on_data_updated)
        self.hub.subscribe(self,
                           DataCollectionDeleteMessage,
                           handler=self._on_data_updated)

        self._selected_data = None
        self._config = self.app.state.settings.get("configuration")
        # The spectral/spatial mode choice is only offered in cubeviz
        if self._config == "cubeviz":
            self.show_modes = True

    def _on_data_updated(self, msg):
        """Refresh the data dropdown when the data collection changes."""
        self.dc_items = [x.label for x in self.data_collection]

    @observe("selected_data")
    def _on_data_selected(self, event):
        """Keep a reference to the data object whose label was selected."""
        self._selected_data = next(
            (x for x in self.data_collection if x.label == event['new']))

    def _get_selected_object(self, cls):
        """
        Return the selected data translated to ``cls``, or None (after
        broadcasting an error snackbar) if that is not possible.
        """
        try:
            return self._selected_data.get_object(cls=cls)
        except (AttributeError, TypeError):
            # AttributeError: nothing selected yet (_selected_data is None).
            # TypeError: presumably the data cannot be translated to ``cls``.
            snackbar_message = SnackbarMessage(
                "Unable to perform smoothing over selected data.",
                color="error",
                sender=self)
            self.hub.broadcast(snackbar_message)
            return None

    def _label_exists(self, label):
        """
        Return True (after broadcasting an error snackbar) if ``label`` is
        already present in the data collection, otherwise False.
        """
        if label not in self.data_collection:
            return False
        snackbar_message = SnackbarMessage(
            "Data with selected stddev already exists, canceling operation.",
            color="error",
            sender=self)
        self.hub.broadcast(snackbar_message)
        return True

    def vue_spectral_smooth(self, *args, **kwargs):
        """
        Smooth the selected data as a ``Spectrum1D`` with ``gaussian_smooth``
        using ``stddev`` from the dialog, and add the result as
        ``"Smoothed <label> stddev <stddev>"``.
        """
        size = float(self.stddev)

        spec = self._get_selected_object(Spectrum1D)
        if spec is None:
            return

        # Check for a duplicate label before doing any work.
        label = f"Smoothed {self._selected_data.label} stddev {size}"
        if self._label_exists(label):
            return

        spec_smoothed = gaussian_smooth(spec, stddev=size)

        self.data_collection[label] = spec_smoothed

        snackbar_message = SnackbarMessage(
            f"Data set '{self._selected_data.label}' smoothed successfully.",
            color="success",
            sender=self)
        self.hub.broadcast(snackbar_message)

    def vue_spatial_convolution(self, *args):
        """
        Use astropy convolution machinery to smooth the spatial dimensions of
        the data cube with a 2D Gaussian kernel of standard deviation
        ``stddev``, and add the result as
        ``"Smoothed <label> spatial stddev <stddev>"``.
        """
        size = float(self.stddev)

        cube = self._get_selected_object(SpectralCube)
        if cube is None:
            return

        # Check for a duplicate label before the (potentially slow)
        # convolution of the whole cube.
        label = f"Smoothed {self._selected_data.label} spatial stddev {size}"
        if self._label_exists(label):
            return

        # Extend the 2D kernel to have a length 1 spectral dimension, so that
        # we can do "3d" convolution to the whole cube
        kernel = np.expand_dims(Gaussian2DKernel(size), 0)

        # TODO: in vuetify >2.3, timeout should be set to -1 to keep open
        #  indefinitely
        snackbar_message = SnackbarMessage(
            "Smoothing spatial slices of cube...",
            loading=True,
            timeout=0,
            sender=self)
        self.hub.broadcast(snackbar_message)

        convolved_data = convolve(cube.hdu.data, kernel)
        # Create a new cube with the old metadata. Note that astropy
        # convolution generates values for masked (NaN) data, but we keep the
        # original mask here.
        newcube = SpectralCube(data=convolved_data,
                               wcs=cube.wcs,
                               mask=cube.mask,
                               meta=cube.meta,
                               fill_value=cube.fill_value)

        self.data_collection[label] = newcube

        snackbar_message = SnackbarMessage(
            f"Data set '{self._selected_data.label}' smoothed successfully.",
            color="success",
            sender=self)
        self.hub.broadcast(snackbar_message)
示例#13
0
class RedshiftSlider(TemplateMixin):
    """
    Plugin providing a slider (with a linked text box) that sets either the
    redshift or the radial velocity in km/s, depending on ``slider_type``.

    Changes are applied to the spectral line list of the spectrum viewer and
    broadcast as a ``RedshiftMessage``. Incoming ``RedshiftMessage``s (e.g.
    from Specviz helper methods) can set the slider value, bounds and step.
    """
    template = load_template("redshift_slider.vue", __file__).tag(sync=True)
    slider = Any(0).tag(sync=True)
    slider_textbox = Any(0).tag(sync=True)
    slider_type = Any("Redshift").tag(sync=True)
    min_value = Float(0).tag(sync=True)
    max_value = Float(0.1).tag(sync=True)
    slider_step = Float(0.00001).tag(sync=True)
    linked = Bool(True).tag(sync=True)
    wait = Int(100).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._watched_viewers = []

        # Watch for new data to grab its redshift if it has one
        self.session.hub.subscribe(self,
                                   AddDataMessage,
                                   handler=self._on_data_added)
        # Watch for messages from Specviz helper redshift functions
        self.session.hub.subscribe(self,
                                   RedshiftMessage,
                                   handler=self._parse_redshift_msg)

        # Bound-update strategy for each slider mode
        self._update_bounds = {
            "Redshift": self._update_bounds_redshift,
            "RV (km/s)": self._update_bounds_rv
        }

    def _on_data_added(self, msg):
        """
        When data is added to a profile viewer, move the slider to that
        data's redshift (or radial velocity), widening the bounds if needed.
        """
        if isinstance(msg.viewer, BqplotProfileView):
            label = msg.data.label
            temp_data = self.app.get_data_from_viewer("spectrum-viewer")[label]
            if self.slider_type == "Redshift":
                new_z = temp_data.redshift.value
                if new_z < self.min_value or new_z > self.max_value:
                    self._update_bounds_redshift(new_z)
                self.slider = new_z
            else:
                new_rv = temp_data.radial_velocity.to("km/s").value
                if new_rv < self.min_value or new_rv > self.max_value:
                    self._update_bounds_rv(new_rv)
                self.slider = new_rv

    def _parse_redshift_msg(self, msg):
        '''
        Handle incoming redshift messages from the app hub. Generally these
        will be created by Specviz helper methods.
        '''
        # Ignore our own broadcasts to avoid feedback loops
        if msg.sender == self:
            return

        param = msg.param
        val = float(msg.value)

        if param == "slider_min":
            self.min_value = val
        elif param == "slider_max":
            self.max_value = val
        elif param == "slider_step":
            self.slider_step = val
        elif param == "redshift":
            if val > self.max_value or val < self.min_value:
                self._update_bounds[self.slider_type](val)
            self.slider = val

    def _velocity_to_redshift(self, velocity):
        """
        Convert a velocity to a relativistic redshift.
        """
        beta = velocity / c
        return np.sqrt((1 + beta) / (1 - beta)) - 1

    def _redshift_to_velocity(self, redshift):
        """
        Convert a relativistic redshift to a velocity.
        """
        zponesq = (1 + redshift)**2
        return c * (zponesq - 1) / (zponesq + 1)

    def _propagate_redshift(self):
        """
        When the redshift is changed with the slider, send the new value to
        the spectral line list of the spectrum viewer and broadcast it back
        to the Specviz helper.
        """
        # Ignore transient, not-yet-numeric input
        if self.slider == "" or self.slider == "-":
            return
        if self.slider_type == "Redshift":
            z = u.Quantity(self.slider)
        else:
            z = self._velocity_to_redshift(u.Quantity(self.slider, "km/s"))

        viewer = self.app.get_viewer('spectrum-viewer')
        line_list = viewer.spectral_lines
        if line_list is not None:
            line_list["redshift"] = z
            # Replot with the new redshift
            viewer.plot_spectral_lines()

        # Send the redshift back to the Specviz helper
        msg = RedshiftMessage("redshift", z.value, sender=self)
        self.app.hub.broadcast(msg)

    def _set_bounds_orderly(self, new_min, new_max, new_val):
        '''Have to do this in the right order so our slider value is never out of bounds'''
        if new_val > self.max_value:
            self.max_value = new_max
            self.slider = new_val
            self.min_value = new_min
        elif new_val < self.min_value:
            self.min_value = new_min
            self.slider = new_val
            self.max_value = new_max
        else:
            self.min_value = new_min
            self.max_value = new_max

    def _update_bounds_redshift(self, new_val):
        '''Set reasonable slider parameters based on manually set redshift'''
        # +/-0.5 around the value, but never below 0 for non-negative input
        if new_val >= 0 and new_val - 0.5 < 0:
            new_min = 0
        else:
            new_min = new_val - 0.5
        new_max = new_val + 0.5

        self._set_bounds_orderly(new_min, new_max, new_val)

        self.slider_step = 0.001

    def _update_bounds_rv(self, new_val):
        '''Set reasonable slider parameters based on manually set radial velocity'''
        if 0 <= new_val < 100000:
            new_min = 0
            new_max = new_val + 100000
            step = 500
        elif -100000 < new_val < 0:
            new_min = new_val - 100000
            new_max = 0
            step = 500
        else:
            # Large |velocity|: use a +/-1% window. Work with the magnitude so
            # that negative velocities still give new_min < new_max and a
            # positive step (using new_val directly inverted the bounds).
            margin = abs(new_val) / 100.0
            new_min = new_val - margin
            new_max = new_val + margin
            step = abs(new_val) / 10000.0

        self._set_bounds_orderly(new_min, new_max, new_val)

        self.slider_step = step

    @observe('slider_textbox')
    def _on_textbox_change(self, event):
        """Move the slider to a numeric value typed into the text box."""
        try:
            val = float(event["new"])
        except (TypeError, ValueError):
            # Partial input such as "" or "-" while typing
            return

        if val > self.max_value or val < self.min_value:
            self._update_bounds[self.slider_type](val)

        if self.slider != val:
            self.slider = val

    @observe('slider')
    def _on_slider_updated(self, event):
        """
        Normalize the slider value to a float (empty -> 0), widen the bounds
        if it is out of range, sync the text box and propagate the redshift.
        """
        if not event['new']:
            value = 0
        else:
            value = float(event['new'])

        if value > self.max_value or value < self.min_value:
            self._update_bounds[self.slider_type](value)
        self.slider = value

        # The text box may hold partial, non-numeric input (e.g. "-");
        # treat that as out of sync instead of raising.
        try:
            textbox_value = float(self.slider_textbox)
        except (TypeError, ValueError):
            textbox_value = None
        if self.slider != textbox_value:
            self.slider_textbox = self.slider

        self._propagate_redshift()

    @observe('slider_type')
    def _on_type_updated(self, event):
        """Convert the current slider value when switching z <-> RV."""
        if event['new'] == "Redshift":
            new_val = self._velocity_to_redshift(
                u.Quantity(self.slider, "km/s")).value
            self._update_bounds_redshift(new_val)
            self.slider = new_val
        else:
            new_val = self._redshift_to_velocity(self.slider).to('km/s').value
            self._update_bounds_rv(new_val)
            self.slider = new_val
示例#14
0
class Collapse(TemplateMixin):
    """
    Plugin that collapses the selected data cube along one axis with the
    chosen reduction function and adds the result to the data collection,
    linked pixel-wise to the original dataset.
    """
    template = load_template("collapse.vue", __file__).tag(sync=True)
    data_items = List([]).tag(sync=True)
    selected_data_item = Unicode().tag(sync=True)
    axes = List([]).tag(sync=True)
    selected_axis = Int(0).tag(sync=True)
    funcs = List(['Mean', 'Median', 'Min', 'Max']).tag(sync=True)
    selected_func = Unicode('Mean').tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.hub.subscribe(self, DataCollectionAddMessage,
                           handler=self._on_data_updated)
        self.hub.subscribe(self, DataCollectionDeleteMessage,
                           handler=self._on_data_updated)

        self._selected_data = None

    def _on_data_updated(self, msg):
        """Refresh the data dropdown when the data collection changes."""
        self.data_items = [x.label for x in self.data_collection]

    @observe('selected_data_item')
    def _on_data_item_selected(self, event):
        """Store the selected data object and list its axes."""
        self._selected_data = next((x for x in self.data_collection
                                    if x.label == event['new']))

        self.axes = list(range(len(self._selected_data.shape)))

    def vue_collapse(self, *args, **kwargs):
        """
        Collapse the selected cube along ``selected_axis`` with
        ``selected_func``, add the result as ``"Collapsed <label>"`` and link
        the two remaining pixel axes to the original dataset.
        """
        try:
            spec = self._selected_data.get_object(cls=SpectralCube)
        except AttributeError:
            # e.g. nothing selected yet (_selected_data is None)
            snackbar_message = SnackbarMessage(
                "Unable to perform collapse over selected data.",
                color="error",
                sender=self)
            self.hub.broadcast(snackbar_message)

            return

        collapsed_spec = getattr(spec, self.selected_func.lower())(
            axis=self.selected_axis)

        data = Data(coords=collapsed_spec.wcs)
        data['flux'] = collapsed_spec.filled_data[...]
        data.get_component('flux').units = str(collapsed_spec.unit)
        data.meta.update(collapsed_spec.meta)

        label = f"Collapsed {self._selected_data.label}"

        self.data_collection[label] = data

        # Link the new dataset pixel-wise to the original dataset. In general
        # direct pixel to pixel links are the most efficient and should be
        # used in cases like this where we know there is a 1-to-1 mapping of
        # pixel coordinates. Here which axes are linked to which depends on
        # the selected axis. The label is not unique across repeated
        # collapses, so link the Data object we just added instead of
        # looking it up by label (which returns the first match).
        (i1, i2), (i1c, i2c) = AXES_MAPPING[self.selected_axis]

        self.data_collection.add_link(
            LinkSame(self._selected_data.pixel_component_ids[i1],
                     data.pixel_component_ids[i1c]))
        self.data_collection.add_link(
            LinkSame(self._selected_data.pixel_component_ids[i2],
                     data.pixel_component_ids[i2c]))

        snackbar_message = SnackbarMessage(
            f"Data set '{self._selected_data.label}' collapsed successfully.",
            color="success",
            sender=self)
        self.hub.broadcast(snackbar_message)
示例#15
0
class ModelFitting(TemplateMixin):
    """
    Plugin for building a compound model from component models, fitting it
    to a spectrum shown in the spectrum viewer, adding the resulting model
    spectrum to the data collection, and pickling the fitted model.
    """
    dialog = Bool(False).tag(sync=True)
    template = load_template("model_fitting.vue", __file__).tag(sync=True)
    dc_items = List([]).tag(sync=True)

    save_enabled = Bool(False).tag(sync=True)
    model_label = Unicode().tag(sync=True)
    model_save_path = Unicode().tag(sync=True)
    temp_name = Unicode().tag(sync=True)
    temp_model = Unicode().tag(sync=True)
    model_equation = Unicode().tag(sync=True)
    eq_error = Bool(False).tag(sync=True)
    component_models = List([]).tag(sync=True)
    display_order = Bool(False).tag(sync=True)
    poly_order = Int(0).tag(sync=True)

    available_models = List(list(MODELS.keys())).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._viewer_spectra = None
        self._spectrum1d = None
        self._units = {}
        self.n_models = 0
        self._fitted_model = None
        self._fitted_spectrum = None
        self.component_models = []
        self._initialized_models = {}
        self._display_order = False
        self._label_to_link = ""
        self.model_save_path = os.getcwd()
        self.model_label = "Model"

        self.hub.subscribe(self, AddDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self, RemoveDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self, SubsetCreateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self, SubsetDeleteMessage,
                           handler=lambda x: self._on_viewer_data_changed())

        self.hub.subscribe(self, SubsetUpdateMessage,
                           handler=lambda x: self._on_viewer_data_changed())

    def _on_viewer_data_changed(self, msg=None):
        """
        Callback method for when data is added or removed from a viewer, or
        when a subset is created, deleted, or updated. This method receives
        a glue message containing viewer information in the case of the former
        set of events, and updates the available data list displayed to the
        user.

        Notes
        -----
        We do not attempt to parse any data at this point, as it can cause
        visible lag in the application.

        Parameters
        ----------
        msg : `glue.core.Message`
            The glue message passed to this callback method.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Subsets are global and are not linked to specific viewer instances,
        # so it's not required that we match any specific ids for that case.
        # However, if the msg is not none, check to make sure that it's the
        # viewer we care about.
        if msg is not None and msg.viewer_id != self._viewer_id:
            return

        viewer = self.app.get_viewer('spectrum-viewer')

        self.dc_items = [layer_state.layer.label
                         for layer_state in viewer.state.layers]

    def _param_units(self, param, order=0):
        """
        Helper function to handle units that depend on x and y.

        Returns the y unit for amplitude-like parameters, y/x for "slope",
        y/x**order for polynomial coefficients ("poly"), and the x unit
        otherwise.
        """
        y_params = ["amplitude", "amplitude_L", "intercept"]

        if param == "slope":
            return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"]))
        elif param == "poly":
            return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"])**order)

        return self._units["y"] if param in y_params else self._units["x"]

    def _update_parameters_from_fit(self):
        """Insert the results of the model fit into the component_models"""
        for m in self.component_models:
            name = m["id"]
            if len(self.component_models) > 1:
                m_fit = self._fitted_model.unitless_model[name]
            else:
                m_fit = self._fitted_model
            temp_params = []
            for i in range(0, len(m_fit.parameters)):
                temp_param = [x for x in m["parameters"] if x["name"] ==
                              m_fit.param_names[i]]
                temp_param[0]["value"] = m_fit.parameters[i]
                temp_params += temp_param
            m["parameters"] = temp_params
        # Trick traitlets into updating the displayed values
        component_models = self.component_models
        self.component_models = []
        self.component_models = component_models

    def _update_initialized_parameters(self):
        """
        Copy the (possibly user-edited) parameter values and units from
        component_models onto the corresponding initialized models.
        """
        for m in self.component_models:
            name = m["id"]
            for param in m["parameters"]:
                quant_param = u.Quantity(param["value"], param["unit"])
                setattr(self._initialized_models[name], param["name"],
                        quant_param)

    def vue_data_selected(self, event):
        """
        Callback method for when the user has selected data from the drop down
        in the front-end. It is here that we actually parse and create a new
        data object from the selected data. From this data object, unit
        information is scraped, and the selected spectrum is stored for later
        use in fitting.

        Parameters
        ----------
        event : str
            IPyWidget callback event object. In this case, represents the data
            label of the data collection object selected by the user.
        """
        selected_spec = self.app.get_data_from_viewer("spectrum-viewer",
                                                      data_label=event)

        # Units are only recorded for the first selection
        if self._units == {}:
            self._units["x"] = str(
                selected_spec.spectral_axis.unit)
            self._units["y"] = str(
                selected_spec.flux.unit)

        # Remember the first viewer layer that is actual data (not a subset)
        # so a registered model spectrum can be linked to it.
        for label in self.dc_items:
            if label in self.data_collection:
                self._label_to_link = label
                break

        self._spectrum1d = selected_spec

    def vue_model_selected(self, event):
        """Remember the chosen model type; polynomials also need an order."""
        self.temp_model = event
        self.display_order = event == "Polynomial1D"

    def _initialize_polynomial(self, new_model):
        """
        Initialize a Polynomial1D of degree ``poly_order`` against the
        selected spectrum and append its coefficient parameters c0..cN to
        ``new_model["parameters"]``.
        """
        initialized_model = initialize(
            MODELS[self.temp_model](name=self.temp_name, degree=self.poly_order),
            self._spectrum1d.spectral_axis,
            self._spectrum1d.flux)

        self._initialized_models[self.temp_name] = initialized_model

        for i in range(self.poly_order + 1):
            param = "c{}".format(i)
            initial_val = getattr(initialized_model, param).value
            new_model["parameters"].append({"name": param,
                                            "value": initial_val,
                                            "unit": self._param_units("poly", i),
                                            "fixed": False})
        return new_model

    def vue_add_model(self, event):
        """Add the selected model and input string ID to the list of models"""
        new_model = {"id": self.temp_name, "model_type": self.temp_model,
                     "parameters": []}

        # Need to do things differently for polynomials, since the order varies
        if self.temp_model == "Polynomial1D":
            new_model = self._initialize_polynomial(new_model)
        else:
            # Have a separate private dict with the initialized models, since
            # they don't play well with JSON for widget interaction
            initialized_model = initialize(
                MODELS[self.temp_model](name=self.temp_name),
                self._spectrum1d.spectral_axis,
                self._spectrum1d.flux)

            self._initialized_models[self.temp_name] = initialized_model

            for param in model_parameters[new_model["model_type"]]:
                initial_val = getattr(initialized_model, param).value
                new_model["parameters"].append({"name": param,
                                                "value": initial_val,
                                                "unit": self._param_units(param),
                                                "fixed": False})

        new_model["Initialized"] = True
        self.component_models = self.component_models + [new_model]

    def vue_remove_model(self, event):
        """Remove the component model whose id is ``event``."""
        self.component_models = [x for x in self.component_models
                                 if x["id"] != event]
        del self._initialized_models[event]

    def vue_save_model(self, event):
        """
        Pickle the fitted model to ``<model_save_path>/<model_label>.pkl``.
        """
        # os.path.join copes with or without a trailing separator, and with
        # an empty save path (indexing [-1] raised IndexError there).
        full_path = os.path.join(self.model_save_path,
                                 self.model_label + ".pkl")
        with open(full_path, 'wb') as f:
            pickle.dump(self._fitted_model, f)

    def vue_equation_changed(self, event):
        """Flag (or clear) an error on the model equation."""
        # Length is a dummy check to test the infrastructure. Recompute on
        # every change so the flag clears once the equation is fixed.
        self.eq_error = len(self.model_equation) > 20

    def vue_model_fitting(self, *args, **kwargs):
        """
        Run fitting on the initialized models, fixing any parameters marked
        as such by the user, then update the displayed parameters with fit
        values
        """
        fitted_model, fitted_spectrum = fit_model_to_spectrum(
            self._spectrum1d,
            self._initialized_models.values(),
            self.model_equation,
            run_fitter=True)
        self._fitted_model = fitted_model
        self._fitted_spectrum = fitted_spectrum

        # Update component model parameters with fitted values
        self._update_parameters_from_fit()

        self.save_enabled = True

    def vue_register_spectrum(self, event):
        """
        Add a spectrum to the data collection based on the currently displayed
        parameters (these could be user input or fit values).
        """
        # Make sure the initialized models are updated with any user-specified
        # parameters
        self._update_initialized_parameters()

        # Need to run the model fitter with run_fitter=False to get spectrum
        model, spectrum = fit_model_to_spectrum(self._spectrum1d,
                                                self._initialized_models.values(),
                                                self.model_equation)

        self.n_models += 1
        label = self.model_label
        if label in self.data_collection:
            self.app.remove_data_from_viewer('spectrum-viewer', label)
            # Some hacky code to remove the label from the data dropdown
            temp_items = []
            for data_item in self.app.state.data_items:
                if data_item['name'] != label:
                    temp_items.append(data_item)
            self.app.state.data_items = temp_items
            # Remove the actual Glue data object from the data_collection
            self.data_collection.remove(self.data_collection[label])
        self.data_collection[label] = spectrum
        self.save_enabled = True
        self.data_collection.add_link(
            LinkSame(self.data_collection[self._label_to_link].pixel_component_ids[0],
                     self.data_collection[label].pixel_component_ids[0]))
示例#16
0
class CoordsInfo(TemplateMixin):
    """
    Plugin widget holding three text fields (pixel, world and value) that are
    synced with its ``coords_info.vue`` template.

    NOTE(review): nothing in this class sets these fields; presumably they are
    updated from viewer callbacks elsewhere — confirm.
    """
    template = load_template("coords_info.vue", __file__).tag(sync=True)
    # Text shown for the pixel coordinates.
    pixel = Unicode("").tag(sync=True)
    # Text shown for the world coordinates.
    world = Unicode("").tag(sync=True)
    # Text shown for the data value.
    value = Unicode("").tag(sync=True)
示例#17
0
class UnitConversion(TemplateMixin):
    """
    Plugin that converts the flux and spectral-axis units of the reference
    spectrum shown in the ``spectrum-viewer``.

    A converted spectrum is added as ``<label>_units_copy_<timestamp>``. When
    the selected data already is such a copy, that copy is replaced in place
    under its existing label.
    """
    template = load_template("unit_conversion.vue", __file__).tag(sync=True)
    # Labels of the layers currently shown in the spectrum viewer.
    dc_items = List([]).tag(sync=True)
    # Label of the spectrum that conversions apply to.
    selected_data = Unicode().tag(sync=True)
    current_flux_unit = Unicode().tag(sync=True)
    current_spectral_axis_unit = Unicode().tag(sync=True)
    # Target units chosen by the user ("" or None means "leave unchanged").
    new_flux_unit = Any().tag(sync=True)
    new_spectral_axis_unit = Any().tag(sync=True)
    # Drop-down options: names of units the current units may convert to.
    spectral_axis_unit_equivalencies = List([]).tag(sync=True)
    flux_unit_equivalencies = List([]).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Data shown in the spectrum viewer, indexed by data label.
        self._viewer_data = None

        # Spectrum targeted by conversions; None when nothing is displayed.
        self.spectrum = None

        self.hub.subscribe(self, AddDataMessage,
                           handler=self._on_viewer_data_changed)

        self.hub.subscribe(self, RemoveDataMessage,
                           handler=self._on_viewer_data_changed)

    def _on_viewer_data_changed(self, msg=None):
        """
        Callback for when data is added to or removed from a viewer.

        Refreshes the viewer data and the list of data labels shown in the
        spectrum viewer, then updates the UI for the viewer's reference data.

        Parameters
        ----------
        msg : `glue.core.Message`, optional
            The glue message passed to this callback. When given, it is
            ignored unless it refers to the spectrum viewer.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Only react to messages that concern the spectrum viewer.
        if msg is not None and msg.viewer_id != self._viewer_id:
            return

        viewer = self.app.get_viewer('spectrum-viewer')

        self._viewer_data = self.app.get_data_from_viewer('spectrum-viewer')

        self.dc_items = [layer_state.layer.label
                         for layer_state in viewer.state.layers]

        self.update_ui()

    def update_ui(self):
        """
        Set up the UI for the spectrum viewer's reference data.

        Shows its current flux and spectral-axis units and fills the drop-down
        lists with candidate target units. When no data is displayed,
        everything is cleared, including the stored spectrum, so that a later
        "apply" cannot act on data that is no longer shown.
        """
        if not self.dc_items:
            self.spectrum = None
            self.selected_data = ""
            self.current_flux_unit = ""
            self.current_spectral_axis_unit = ""
            self.spectral_axis_unit_equivalencies = []
            self.flux_unit_equivalencies = []
            return

        self.selected_data = self.app.get_viewer("spectrum-viewer").state.reference_data.label

        self.spectrum = self._viewer_data[self.selected_data]

        # Set UI label to show current flux and spectral axis units.
        self.current_flux_unit = self.spectrum.flux.unit.to_string()
        self.current_spectral_axis_unit = self.spectrum.spectral_axis.unit.to_string()

        # Populate drop down with all valid options for unit conversion.
        self.spectral_axis_unit_equivalencies = self.create_spectral_equivalencies_list()
        self.flux_unit_equivalencies = self.create_flux_equivalencies_list()

    @staticmethod
    def _unit_change_requested(new_unit, current_unit):
        """Return True if ``new_unit`` is set (not None/"") and differs from ``current_unit``."""
        return new_unit is not None and new_unit != "" and new_unit != current_unit

    def _convert_uncertainty(self, flux, spectral_axis):
        """
        Return the selected spectrum's uncertainty re-expressed for ``flux``.

        Returns None when the spectrum has no uncertainty, when its type is
        not in ``unit_exponents``, or when the conversion fails. The last two
        cases also broadcast a warning snackbar.
        """
        uncertainty = self.spectrum.uncertainty
        if uncertainty is None:
            return None

        unit_exp = unit_exponents.get(uncertainty.__class__)
        # If uncertainty type not in our lookup, drop the uncertainty
        if unit_exp is None:
            self.hub.broadcast(SnackbarMessage(
                "Warning: Unrecognized uncertainty type, cannot guarantee conversion so dropping uncertainty in resulting data",
                color="warning",
                sender=self))
            return None

        try:
            # Take the uncertainty to power 1/unit_exp so it has flux units,
            # convert with the spectral axis as context, then restore the
            # original power. This handles e.g. variance uncertainties between
            # frequency and wavelength space.
            # TODO: simplify this when astropy handles it
            converted = uncertainty.quantity ** (1 / unit_exp)
            converted = converted.to(u.Unit(flux.unit),
                                     equivalencies=u.spectral_density(spectral_axis))
            converted **= unit_exp
            return uncertainty.__class__(converted.value)
        except u.UnitConversionError:
            self.hub.broadcast(SnackbarMessage(
                "Warning: Could not convert uncertainty, setting to None in converted data",
                color="warning",
                sender=self))
            return None

    def vue_unit_conversion(self, *args, **kwargs):
        """
        Runs when the `apply` button is hit.

        Converts the selected spectrum to the requested units (when set,
        valid and different from the current ones). Then either replaces an
        existing ``_units_copy_`` data set or adds a new one and shows it in
        the spectrum viewer. Problems are reported through snackbar messages.
        """
        if self.spectrum is None:
            self.hub.broadcast(SnackbarMessage(
                "No spectrum selected for unit conversion.",
                color="error",
                sender=self))
            return

        set_spectral_axis_unit = self.spectrum.spectral_axis
        set_flux_unit = self.spectrum.flux

        # Try to set new units if set and are valid.
        if self._unit_change_requested(self.new_spectral_axis_unit,
                                       self.current_spectral_axis_unit):
            try:
                set_spectral_axis_unit = self.spectrum.spectral_axis.to(
                    u.Unit(self.new_spectral_axis_unit))
            except ValueError:
                self.hub.broadcast(SnackbarMessage(
                    "Unable to convert spectral axis units for selected data. Try different units.",
                    color="error",
                    sender=self))
                return

        # Try to set new units if set and are valid.
        if self._unit_change_requested(self.new_flux_unit, self.current_flux_unit):
            try:
                set_flux_unit = self.spectrum.flux.to(
                    u.Unit(self.new_flux_unit),
                    equivalencies=u.spectral_density(set_spectral_axis_unit))
            except ValueError:
                self.hub.broadcast(SnackbarMessage(
                    "Unable to convert flux units for selected data. Try different units.",
                    color="error",
                    sender=self))
                return

        # Uncertainty converted to new flux units
        temp_uncertainty = self._convert_uncertainty(set_flux_unit,
                                                     set_spectral_axis_unit)

        # Create new spectrum with new units.
        converted_spec = self.spectrum._copy(flux=set_flux_unit,
                                             spectral_axis=set_spectral_axis_unit,
                                             unit=set_flux_unit.unit,
                                             uncertainty=temp_uncertainty)

        if "_units_copy_" in self.selected_data:
            # The selection already is a unit-conversion copy: replace it in
            # place, keeping its label. The label is captured before removal,
            # because removing data re-triggers update_ui and can change
            # self.selected_data.
            label = self.selected_data

            if label in self.data_collection:
                self.app.remove_data_from_viewer('spectrum-viewer', label)

                # Remove the actual Glue data object from the data_collection
                self.data_collection.remove(self.data_collection[label])
            self.data_collection[label] = converted_spec

            # TODO: Fix bug that sends AddDataMessage into a loop
            self.app.add_data_to_viewer("spectrum-viewer", label)

        else:
            label = self.selected_data + "_units_copy_" + datetime.datetime.now().isoformat()

            # Replace old spectrum with new one with updated units.
            self.app.add_data(converted_spec, label)
            self.app.add_data_to_viewer("spectrum-viewer", label, clear_other_data=True)

        # Reset UI labels.
        self.new_flux_unit = ""
        self.new_spectral_axis_unit = ""

        self.hub.broadcast(SnackbarMessage(
            f"Data set '{label}' units converted successfully.",
            color="success",
            sender=self))

    def create_spectral_equivalencies_list(self):
        """
        Return the names of units the current spectral axis can be converted to.

        The locally defined units come first (sorted). They are followed by
        every other unit astropy reports as equivalent under ``u.spectral()``
        (sorted).
        """
        equivalent_units = u.Unit(
            self.spectrum.spectral_axis.unit).find_equivalent_units(
            equivalencies=u.spectral())

        local_units = [u.Unit(unit) for unit in self._locally_defined_spectral_axis_units()]

        # Remove overlap units.
        other_units = list(set(equivalent_units) - set(local_units))

        return (sorted(self.convert_units_to_strings(local_units))
                + sorted(self.convert_units_to_strings(other_units)))

    def create_flux_equivalencies_list(self):
        """
        Return the names of units the current flux can be converted to.

        The locally defined units come first (sorted). They are followed by
        the other equivalent units found with ``u.spectral_density``
        (prefixed units excluded, sorted).
        """
        # NOTE(review): the summed spectral axis is used only as the
        # spectral_density reference value here — confirm it is intended.
        equivalent_units = u.Unit(
            self.spectrum.flux.unit).find_equivalent_units(
            equivalencies=u.spectral_density(np.sum(self.spectrum.spectral_axis)),
            include_prefix_units=False)

        local_units = [u.Unit(unit) for unit in self._locally_defined_flux_units()]

        # Remove overlap units.
        other_units = list(set(equivalent_units) - set(local_units))

        return (sorted(self.convert_units_to_strings(local_units))
                + sorted(self.convert_units_to_strings(other_units)))

    @staticmethod
    def _locally_defined_flux_units():
        """
        List of flux unit strings always offered in the flux drop-down.
        """
        units = ['Jy', 'mJy', 'uJy',
                 'W / (m2 Hz)',
                 'eV / (s m2 Hz)',
                 'erg / (s cm2)',
                 'erg / (s cm2 um)',
                 'erg / (s cm2 Angstrom)',
                 'erg / (s cm2 Hz)',
                 'ph / (s cm2 um)',
                 'ph / (s cm2 Angstrom)',
                 'ph / (s cm2 Hz)']
        return units

    @staticmethod
    def _locally_defined_spectral_axis_units():
        """
        List of spectral-axis unit strings always offered in the spectral-axis
        drop-down.
        """
        units = ['angstrom', 'nanometer',
                 'micron', 'hertz', 'erg']
        return units

    def convert_units_to_strings(self, unit_list):
        """
        Convert equivalencies into readable versions of the units.

        Angstrom keeps its short name. Otherwise the first long name is used
        when the unit has one, falling back to ``to_string()``.

        Parameters
        ----------
        unit_list : list
            List of either `astropy.units` or strings that can be converted
            to `astropy.units`.

        Returns
        -------
        list
            A list of the units with their best (i.e., most readable) string version.
        """
        angstrom = u.Unit("Angstrom")

        def readable(unit):
            # Parse once instead of once per attribute access.
            unit = u.Unit(unit)
            if unit == angstrom:
                return unit.name
            long_names = getattr(unit, "long_names", None)
            if long_names:
                return long_names[0]
            return unit.to_string()

        return [readable(unit) for unit in unit_list]
示例#18
0
class SlitOverlay(TemplateMixin):
    """
    Plugin that draws the slit footprint described by a NIRSpec 2D spectrum's
    ``S_REGION`` metadata on top of the image viewer.

    The overlay is redrawn whenever the table viewer's ``highlighted`` trait
    changes.
    """
    template = load_template("slit_overlay.vue", __file__).tag(sync=True)
    visible = Bool(True).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Must exist before the observer below is registered, because the
        # callback (place_slit_overlay -> remove_slit_overlay) reads it.
        self._slit_overlay_mark = None

        table = self.app.get_viewer("table-viewer")
        table.figure_widget.observe(self.place_slit_overlay,
                                    names=['highlighted'])

    def vue_change_visible(self, *args, **kwargs):
        """Draw or remove the slit overlay according to ``visible``."""
        if self.visible:
            self.place_slit_overlay()
        else:
            self.remove_slit_overlay()

    def place_slit_overlay(self, *args, **kwargs):
        """
        Find slit information in 2D Spectrum metadata, find the correct
        wcs information from the image metadata, then plot the slit over the
        image viewer using both.

        Nothing is drawn when the overlay is hidden or the image viewer has
        no reference data. If the first 2D spectrum is not NIRSpec data with
        an ``S_REGION`` entry, the overlay is switched off and a warning
        snackbar is broadcast.
        """
        if not self.visible:
            return

        # Clear existing slits on the image viewer
        self.remove_slit_overlay()

        # Get data from relevant viewers
        image_viewer = self.app.get_viewer("image-viewer")
        image_data = image_viewer.state.reference_data
        spec2d_data = self.app.get_viewer("spectrum-2d-viewer").data()

        # Bypass in case no images exist.
        if image_data is None:
            return

        # Only use S_REGION for NIRSpec data; turn the plugin off if other
        # data is loaded.
        has_nirspec_slit = (
            len(spec2d_data) > 0
            and 'S_REGION' in spec2d_data[0].meta
            and spec2d_data[0].meta.get('INSTRUME', '').lower() == "nirspec")
        if not has_nirspec_slit:
            self.visible = False
            self.hub.broadcast(SnackbarMessage(
                "'S_REGION' not found in Spectrum 2D meta attribute, "
                "turning slit overlay off",
                color="warning",
                sender=self))
            return

        sky_region = jwst_header_to_skyregion(spec2d_data[0].meta)

        # Use wcs of image viewer to scale slit dimensions correctly
        pixel_region = sky_region.to_pixel(image_data.coords)

        # Polygon vertices of the slit in image pixel coordinates.
        vertices = pixel_region.to_polygon().vertices

        fig_image = image_viewer.figure

        # Deactivate any active toolbar tool. Presumably this is so that
        # fig_image.interaction below is the default one carrying the image
        # scales — confirm.
        if image_viewer.toolbar.active_tool is not None:
            image_viewer.toolbar.active_tool = None

        # Reuse the image figure's scales so the slit lines up with the image.
        scales = {
            'x': fig_image.interaction.x_scale,
            'y': fig_image.interaction.y_scale
        }

        slit_mark = bqplot.Lines(x=vertices.x,
                                 y=vertices.y,
                                 scales=scales,
                                 fill='none',
                                 colors=["red"],
                                 stroke_width=2,
                                 close_path=True)

        # Assign a new list so traitlets registers the change.
        fig_image.marks = fig_image.marks + [slit_mark]

        self._slit_overlay_mark = slit_mark

    def remove_slit_overlay(self):
        """
        Remove the slit mark drawn by this plugin from the image viewer, if any.

        This is safe to call when the mark has already disappeared from the
        figure.
        """
        if self._slit_overlay_mark is None:
            return

        image_figure = self.app.get_viewer("image-viewer").figure
        # Build a new list instead of mutating marks in place, otherwise
        # traitlets doesn't register a change in the marks.
        image_figure.marks = [mark for mark in image_figure.marks
                              if mark is not self._slit_overlay_mark]
        self._slit_overlay_mark = None
示例#19
0
class ModelFitting(TemplateMixin):
    dialog = Bool(False).tag(sync=True)
    template = load_template("model_fitting.vue", __file__).tag(sync=True)
    dc_items = List([]).tag(sync=True)

    save_enabled = Bool(False).tag(sync=True)
    model_label = Unicode().tag(sync=True)
    model_save_path = Unicode().tag(sync=True)
    temp_name = Unicode().tag(sync=True)
    temp_model = Unicode().tag(sync=True)
    model_equation = Unicode().tag(sync=True)
    eq_error = Bool(False).tag(sync=True)
    component_models = List([]).tag(sync=True)
    display_order = Bool(False).tag(sync=True)
    poly_order = Int(0).tag(sync=True)

    available_models = List(list(MODELS.keys())).tag(sync=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # State derived from the currently selected spectrum.
        self._viewer_spectra = None
        self._spectrum1d = None
        self._units = {}
        self.n_models = 0

        # Results of the most recent fit.
        self._fitted_model = None
        self._fitted_spectrum = None

        # JSON-friendly component descriptions for the UI, plus the matching
        # initialized model objects keyed by component id.
        self.component_models = []
        self._initialized_models = {}

        self._display_order = False
        self.model_save_path = os.getcwd()
        self.model_label = "Model"
        self._selected_data_label = None

        # Data add/remove messages are passed through so the handler can
        # check which viewer they target.
        for data_message in (AddDataMessage, RemoveDataMessage):
            self.hub.subscribe(self,
                               data_message,
                               handler=self._on_viewer_data_changed)

        # Subsets are global (not tied to a viewer), so the handler is called
        # without the message to skip the viewer-id check.
        for subset_message in (SubsetCreateMessage,
                               SubsetDeleteMessage,
                               SubsetUpdateMessage):
            self.hub.subscribe(self,
                               subset_message,
                               handler=lambda msg: self._on_viewer_data_changed())

    def _on_viewer_data_changed(self, msg=None):
        """
        Refresh the data labels offered in the UI.

        Called when data is added to or removed from a viewer (with the glue
        message), or when a subset is created, deleted or updated (without a
        message). The labels are read from the spectrum viewer's layers. No
        data is parsed here, as that can cause visible lag in the application.

        Parameters
        ----------
        msg : `glue.core.Message`, optional
            The glue message passed to this callback. When given, it is
            ignored unless it targets the spectrum viewer.
        """
        self._viewer_id = self.app._viewer_item_by_reference(
            'spectrum-viewer').get('id')

        # Subset callbacks pass no message, since subsets are global. Data
        # messages only matter when they concern the spectrum viewer.
        targets_other_viewer = msg is not None and msg.viewer_id != self._viewer_id
        if targets_other_viewer:
            return

        layers = self.app.get_viewer('spectrum-viewer').state.layers
        self.dc_items = [state.layer.label for state in layers]

    def _param_units(self, param, order=0):
        """
        Return the unit string for a model parameter, derived from the
        stored x (spectral axis) and y (flux) units.

        ``"slope"`` gets y/x and ``"poly"`` gets y/x**order. Amplitude-like
        parameters and intercepts get the y unit; everything else gets the
        x unit.
        """
        if param == "slope":
            return str(u.Unit(self._units["y"]) / u.Unit(self._units["x"]))
        if param == "poly":
            return str(
                u.Unit(self._units["y"]) / u.Unit(self._units["x"])**order)

        flux_like = ("amplitude", "amplitude_L", "intercept")
        axis = "y" if param in flux_like else "x"
        return self._units[axis]

    def _update_parameters_from_fit(self):
        """
        Copy fitted parameter values into ``component_models``.

        With several components, each component's fitted submodel is looked
        up by its id. With a single component, the fitted model itself is
        used. Each component's parameter list is rebuilt in the fitted
        model's parameter order.
        """
        single_component = len(self.component_models) <= 1
        for component in self.component_models:
            if single_component:
                fitted = self._fitted_model
            else:
                fitted = self._fitted_model[component["id"]]

            refreshed = []
            for idx, value in enumerate(fitted.parameters):
                wanted = fitted.param_names[idx]
                matches = [p for p in component["parameters"] if p["name"] == wanted]
                matches[0]["value"] = value
                refreshed += matches
            component["parameters"] = refreshed

        # Re-assign through an empty list so traitlets notices a change and
        # the displayed values refresh.
        models = self.component_models
        self.component_models = []
        self.component_models = models

    def _update_parameters_from_QM(self):
        """
        Parse out result parameters from a QuantityModel, which isn't
        subscriptable with model name.

        For compound results, parameter names carry a ``_<submodel index>``
        suffix. The parameter is matched to its submodel by that index and to
        the component's parameter by the name without the suffix. Each
        matched parameter's ``value`` is updated and the component's parameter
        list is rebuilt in the fitted order.
        """
        fitted = self._fitted_model
        has_submodels = hasattr(fitted, "submodel_names")
        submodel_names = fitted.submodel_names if has_submodels else [fitted.name]
        fit_params = fitted.parameters
        param_names = fitted.param_names

        for model_idx, name in enumerate(submodel_names):
            component = [x for x in self.component_models if x["id"] == name][0]
            updated = []
            for idx, full_name in enumerate(param_names):
                if has_submodels:
                    # The base parameter name may itself contain underscores,
                    # so split on the last one only. Parse the whole suffix,
                    # not just its last character, so that submodel indices
                    # >= 10 are attributed correctly.
                    base_name, _, suffix = full_name.rpartition("_")
                    if int(suffix) != model_idx:
                        continue
                else:
                    base_name = full_name
                matches = [x for x in component["parameters"] if x["name"] == base_name]
                matches[0]["value"] = fit_params[idx]
                updated += matches
            component["parameters"] = updated

        # Trick traitlets into updating the displayed values
        component_models = self.component_models
        self.component_models = []
        self.component_models = component_models

    def _update_initialized_parameters(self):
        """
        Push the parameter values shown in the UI, with their units, onto the
        matching initialized model objects. This keeps user edits in sync.
        """
        for component in self.component_models:
            model_id = component["id"]
            for entry in component["parameters"]:
                value = u.Quantity(entry["value"], entry["unit"])
                setattr(self._initialized_models[model_id], entry["name"], value)

    def _warn_if_no_equation(self):
        """
        Return True after broadcasting an error snackbar if no model equation
        is set, otherwise return False.

        The snackbar suggests joining all component ids with ``+``.
        """
        if self.model_equation:
            return False

        example = "+".join(component["id"] for component in self.component_models)
        self.hub.broadcast(SnackbarMessage(
            f"Error: a model equation must be defined, e.g. {example}",
            color='error',
            sender=self))
        return True

    def vue_data_selected(self, event):
        """
        Handle the user picking data from the front-end drop down.

        The selected data is fetched from the spectrum viewer, its NaN fluxes
        are zeroed, and it is stored for later fitting. Its spectral-axis and
        flux units are recorded only if no units have been stored yet.

        Parameters
        ----------
        event : str
            Data label of the data collection object selected by the user.
        """
        spectrum = self.app.get_data_from_viewer("spectrum-viewer",
                                                 data_label=event)
        # Replace NaNs from collapsed SpectralCube in Cubeviz
        # (won't affect calculations because these locations are masked)
        nan_locations = np.isnan(spectrum.flux)
        spectrum.flux[nan_locations] = 0.0

        self._selected_data_label = event

        # Units come from the first selection only; later selections keep them.
        if not self._units:
            self._units["x"] = str(spectrum.spectral_axis.unit)
            self._units["y"] = str(spectrum.flux.unit)

        self._spectrum1d = spectrum

    def vue_model_selected(self, event):
        """
        Remember the model type chosen in the UI. The polynomial-order input
        is shown only for ``Polynomial1D``.
        """
        self.temp_model = event
        self.display_order = (event == "Polynomial1D")

    def _initialize_polynomial(self, new_model):
        """
        Initialize a polynomial of degree ``poly_order`` against the selected
        spectrum and record it under ``temp_name``.

        The order is stored on ``new_model`` and one entry per coefficient
        ``c0 .. c<order>`` is appended to ``new_model["parameters"]``. Returns
        ``new_model``.
        """
        polynomial = MODELS[self.temp_model](name=self.temp_name,
                                             degree=self.poly_order)
        initialized_model = initialize(polynomial,
                                       self._spectrum1d.spectral_axis,
                                       self._spectrum1d.flux)

        self._initialized_models[self.temp_name] = initialized_model
        new_model["order"] = self.poly_order

        for power in range(self.poly_order + 1):
            coeff_name = f"c{power}"
            start_value = getattr(initialized_model, coeff_name).value
            new_model["parameters"].append({
                "name": coeff_name,
                "value": start_value,
                "unit": self._param_units("poly", power),
                "fixed": False
            })

        self._update_initialized_parameters()

        return new_model

    def _reinitialize_with_fixed(self):
        """
        Rebuild every component model with its ``fixed`` flags applied.

        An existing model's fixed dictionary can't easily be updated, so a new
        instance is created with it and the current parameter values are then
        set on it. Returns the list of new model instances.
        """
        rebuilt = []
        for component in self.component_models:
            fixed = {entry["name"]: entry["fixed"] for entry in component["parameters"]}
            model_cls = MODELS[component["model_type"]]
            if component["model_type"] == "Polynomial1D":
                model = model_cls(name=component["id"],
                                  degree=component["order"],
                                  fixed=fixed)
            else:
                model = model_cls(name=component["id"], fixed=fixed)
            for entry in component["parameters"]:
                setattr(model, entry["name"], entry["value"])
            rebuilt.append(model)
        return rebuilt

    def vue_add_model(self, event):
        """
        Add the selected model type, under the ID typed by the user, to the
        list of component models.

        The model is initialized against the selected spectrum, and its
        initial parameter values and units are recorded for display. Instead,
        an error snackbar is broadcast and nothing is added when:

        - no spectrum has been selected yet, or
        - a component with the same ID already exists. Adding it would
          overwrite the stored initialized model and leave two components
          sharing one ID.
        """
        if self._spectrum1d is None:
            self.hub.broadcast(SnackbarMessage(
                "Error: select data before adding a model component",
                color='error',
                sender=self))
            return
        if self.temp_name in self._initialized_models:
            self.hub.broadcast(SnackbarMessage(
                f"Error: a model component with ID '{self.temp_name}' already exists",
                color='error',
                sender=self))
            return

        new_model = {
            "id": self.temp_name,
            "model_type": self.temp_model,
            "parameters": []
        }

        # Need to do things differently for polynomials, since the order varies
        if self.temp_model == "Polynomial1D":
            new_model = self._initialize_polynomial(new_model)
        else:
            # Have a separate private dict with the initialized models, since
            # they don't play well with JSON for widget interaction
            initialized_model = initialize(
                MODELS[self.temp_model](name=self.temp_name),
                self._spectrum1d.spectral_axis, self._spectrum1d.flux)

            self._initialized_models[self.temp_name] = initialized_model

            for param in model_parameters[new_model["model_type"]]:
                initial_val = getattr(initialized_model, param).value
                new_model["parameters"].append({
                    "name": param,
                    "value": initial_val,
                    "unit": self._param_units(param),
                    "fixed": False
                })

        new_model["Initialized"] = True
        # Assign a new list so traitlets registers the change.
        self.component_models = self.component_models + [new_model]

        self._update_initialized_parameters()

    def vue_remove_model(self, event):
        """
        Drop the component whose id is ``event`` from the UI list and from the
        initialized models. Raises KeyError if no initialized model has that id.
        """
        remaining = [component for component in self.component_models
                     if component["id"] != event]
        self.component_models = remaining
        del self._initialized_models[event]

    def vue_save_model(self, event):
        """
        Pickle the most recently fitted model to
        ``<model_save_path>/<model_label>.pkl``.

        ``os.path.join`` handles the separator. An empty save path writes
        relative to the current working directory instead of raising
        IndexError.
        """
        full_path = os.path.join(self.model_save_path, self.model_label + ".pkl")
        with open(full_path, 'wb') as f:
            pickle.dump(self._fitted_model, f)

    def vue_equation_changed(self, event):
        """
        Update the equation error flag shown in the UI.

        The length limit is only a dummy check to test the infrastructure.
        ``eq_error`` is recomputed on every change, so the error also clears
        once the equation is short enough again.
        """
        self.eq_error = len(self.model_equation) > 20

    def vue_model_fitting(self, *args, **kwargs):
        """
        Run fitting on the initialized models, fixing any parameters marked
        as such by the user, then update the displayed parameters with fit
        values.

        The fitted spectrum is registered and the fitted model is stored in
        ``app._fitted_1d_models`` under ``model_label``. If no model equation
        is set, or the fit raises AttributeError (treated as an invalid
        equation), an error snackbar is shown and nothing else happens.
        """
        if self._warn_if_no_equation():
            return
        models_to_fit = self._reinitialize_with_fixed()

        try:
            fitted_model, fitted_spectrum = fit_model_to_spectrum(
                self._spectrum1d,
                models_to_fit,
                self.model_equation,
                run_fitter=True)
        except AttributeError:
            msg = SnackbarMessage(
                "Unable to fit: model equation may be invalid",
                color="error",
                sender=self)
            self.hub.broadcast(msg)
            return
        self._fitted_model = fitted_model
        self._fitted_spectrum = fitted_spectrum

        self.vue_register_spectrum({"spectrum": fitted_spectrum})
        if not hasattr(self.app, "_fitted_1d_models"):
            self.app._fitted_1d_models = {}
        self.app._fitted_1d_models[self.model_label] = fitted_model

        # Update component model parameters with fitted values. QuantityModel
        # results can't be indexed by submodel name, so they need their own
        # parser.
        if isinstance(self._fitted_model, QuantityModel):
            self._update_parameters_from_QM()
        else:
            self._update_parameters_from_fit()

        self.save_enabled = True

    def vue_fit_model_to_cube(self, *args, **kwargs):
        """
        Fit the initialized models to the selected cube and add the resulting
        model cube to the data collection.

        The fitted model is stored on ``self.app._fitted_3d_model`` so that the
        cubeviz helper can access it. Progress, failures, and completion are
        reported to the user through snackbar messages broadcast on the hub.

        Parameters
        ----------
        *args, **kwargs
            Unused by this method.
        """
        if self._warn_if_no_equation():
            return
        data = self.app.data_collection[self._selected_data_label]

        # First, ensure that the selected data is cube-like. It is possible
        # that the user has selected a pre-existing 1d data object.
        if data.ndim != 3:
            snackbar_message = SnackbarMessage(
                f"Selected data {self._selected_data_label} is not cube-like",
                color='error',
                sender=self)
            self.hub.broadcast(snackbar_message)
            return

        # Get the primary data component
        attribute = data.main_components[0]
        component = data.get_component(attribute)
        temp_values = data.get_data(attribute)

        # Move axis 0 of the glue array to the end. NOTE(review): this assumes
        # axis 0 is the spectral axis, which Spectrum1D expects last.
        values = np.moveaxis(temp_values, 0, -1) * u.Unit(component.units)

        # We manually create a Spectrum1D object from the flux information
        #  in the cube we select
        wcs = data.coords.sub([WCSSUB_SPECTRAL])
        spec = Spectrum1D(flux=values, wcs=wcs)

        # TODO: in vuetify >2.3, timeout should be set to -1 to keep open
        #  indefinitely
        snackbar_message = SnackbarMessage("Fitting model to cube...",
                                           loading=True,
                                           timeout=0,
                                           sender=self)
        self.hub.broadcast(snackbar_message)

        # Retrieve copy of the models with proper "fixed" dictionaries
        # TODO: figure out why this was causing the parallel fitting to fail
        # models_to_fit = self._reinitialize_with_fixed()
        models_to_fit = self._initialized_models.values()

        try:
            fitted_model, fitted_spectrum = fit_model_to_spectrum(
                spec, models_to_fit, self.model_equation, run_fitter=True)
        except AttributeError:
            # Report an unusable model equation instead of letting the error
            # escape; without this the non-expiring "loading" message
            # broadcast above would never be replaced.
            msg = SnackbarMessage(
                "Unable to fit: model equation may be invalid",
                color="error",
                loading=False,
                sender=self)
            self.hub.broadcast(msg)
            return

        # Save fitted 3D model in a way that the cubeviz
        # helper can access it.
        self.app._fitted_3d_model = fitted_model

        # Undo the earlier axis move so the output matches the input layout.
        values = np.moveaxis(fitted_spectrum.flux.value, -1, 0)

        # Choose one more than the largest trailing integer of any existing
        # label so the new label cannot collide with an existing one. The
        # whole trailing digit run is read ("... 10" -> 10, not 0); labels
        # without a trailing number count as 0, and an empty collection
        # yields 1 instead of raising.
        trailing = (re.search(r"\d+$", lbl)
                    for lbl in self.data_collection.labels)
        count = max((int(m.group()) if m else 0 for m in trailing),
                    default=0) + 1

        label = f"{self.model_label} [Cube] {count}"

        # Create new glue data object
        output_cube = Data(label=label, coords=data.coords)
        output_cube['flux'] = values
        output_cube.get_component('flux').units = \
            fitted_spectrum.flux.unit.to_string()

        # Add to data collection
        self.app.data_collection.append(output_cube)

        snackbar_message = SnackbarMessage("Finished cube fitting",
                                           color='success',
                                           loading=False,
                                           sender=self)
        self.hub.broadcast(snackbar_message)

    def vue_register_spectrum(self, event):
        """
        Store a model spectrum in the data collection under ``model_label``.

        If ``event`` carries a ``"spectrum"`` entry, that spectrum is stored
        as-is. Otherwise the initialized models, after syncing any
        user-entered parameter values, are evaluated against the loaded
        spectrum via ``fit_model_to_spectrum``. An existing entry with the same
        label is first removed from the spectrum viewer and from the data
        collection, then replaced.

        Parameters
        ----------
        event : dict-like
            Payload from the caller; an optional ``"spectrum"`` key supplies
            the spectrum to register directly.
        """
        if self._warn_if_no_equation():
            return

        # Push user-specified parameter values into the initialized models.
        self._update_initialized_parameters()

        if "spectrum" not in event:
            # Evaluate the models without passing run_fitter; the intent is
            # to obtain the model spectrum rather than to fit it.
            _, spectrum = fit_model_to_spectrum(
                self._spectrum1d,
                self._initialized_models.values(),
                self.model_equation)
        else:
            spectrum = event["spectrum"]

        self.n_models += 1

        label = self.model_label
        if label in self.data_collection:
            # Detach the stale entry from the viewer, then delete the
            # underlying glue Data object before storing the replacement.
            self.app.remove_data_from_viewer('spectrum-viewer', label)
            self.data_collection.remove(self.data_collection[label])

        self.data_collection[label] = spectrum
        self.save_enabled = True