Exemple #1
0
    def __init__(self,
                 displayer: Callable[[Figure, Selection], Callable[[int, Any],
                                                                   None]],
                 resolution: int = None,
                 fig_kw={},
                 children=(),
                 toolbar=(),
                 **ka):
        """Build the level/precision toolbar and the matplotlib board.

        :param displayer: called once as ``displayer(board, select)`` with the
          figure created by ``self.mpl_figure`` and the ``select`` callback
          defined below; must return a ``display(level, new)`` callable.
          NOTE(review): the annotation says ``display`` returns ``None``, but
          ``setlevel`` writes its return value into the integer "precision"
          box -- confirm the intended return type.
        :param resolution: passed, together with the selected bounds, as the
          ``(resolution, bounds_)`` tuple handed to ``display`` on selection.
        :param fig_kw: keyword arguments forwarded to ``self.mpl_figure``
          (only unpacked here, never mutated, so the shared ``{}`` default is
          harmless in this block).
        :param children: forwarded as the first argument of ``super().__init__``.
        :param toolbar: extra toolbar widgets, appended after the built-in
          running/level/precision controls.
        :param ka: forwarded to ``super(app, self).__init__(display, **ka)``.
        """
        from ipywidgets import BoundedIntText, IntText, Label

        def select(bounds_):
            # Push a new level one above the current one, using the supplied
            # bounds. ``active`` is lowered while the level box is updated so
            # its 'value' observer (registered below) does not call setlevel a
            # second time.
            i = self.level + 1
            w_level.active = False
            w_level.max = i
            w_level.value = i
            setlevel(i, (resolution, bounds_))
            w_level.active = True

        def show_running(b):
            # Button shows 'pause' while running, 'play' otherwise.
            w_running.icon = 'pause' if b else 'play'

        def show_precision(p):
            w_precision.value = p

        # Expose the closures as instance attributes for external callers.
        self.select = select
        self.show_precision = show_precision
        self.show_running = show_running

        def setlevel(i, new=None):
            # Record the level, let the displayer render it (its return value
            # is shown as the precision), then schedule a canvas redraw.
            # ``display`` and ``board`` are bound later in __init__; they are
            # resolved when this closure runs, not when it is defined.
            self.level = i
            w_precision.value = display(i, new)
            board.canvas.draw_idle()

        # global design and widget definitions
        w_running = SimpleButton(icon='')
        w_level = BoundedIntText(0,
                                 min=0,
                                 max=0,
                                 layout=dict(width='1.6cm', padding='0cm'))
        # Ad-hoc flag checked by the 'value' observer below.
        w_level.active = True
        w_precision = IntText(0,
                              disabled=True,
                              layout=dict(width='1.6cm', padding='0cm'))
        super().__init__(children,
                         toolbar=(w_running, Label('level:'), w_level,
                                  Label('precision:'), w_precision, *toolbar))
        self.board = board = self.mpl_figure(**fig_kw)
        display = displayer(board, select)
        # callbacks
        # ``setrunning`` is not defined in this block -- presumably provided by
        # this class or a base class.
        w_running.on_click(lambda b: self.setrunning())
        # User edits of the level box re-display that level (suppressed while
        # ``select`` is updating the box programmatically).
        w_level.observe(
            (lambda c: (setlevel(c.new) if w_level.active else None)), 'value')
        # ``app`` is defined outside this view; ``super(app, self)`` continues
        # the MRO lookup after ``app``, passing ``display`` to that initializer.
        super(app, self).__init__(display, **ka)
Exemple #2
0
    def find_beam_shape_thresh(self):
        """Return an interactive viewer for thresholding ``self.beam_shape``.

        The figure is created lazily on first use and cached in
        ``self._beam_shape_fig``. A 0-100 integer box (default 50) drives a
        boolean mask ``beam_shape * 100 > threshold`` drawn on the figure's
        first axes. Each redraw removes the previously drawn mask artist.

        Returns:
            A ``VBox`` containing the threshold box and the figure canvas.
        """
        if self._beam_shape_fig is None:
            self._beam_shape_fig = plt.subplots()[0]

        target_axes = self._beam_shape_fig.axes[0]

        threshold_box = BoundedIntText(value=50, min=0, max=100,
                                       description="threshold")

        def redraw(change=None):
            stale_artist = self._beam_shape_artist
            if stale_artist is not None:
                stale_artist.remove()
            mask = self.beam_shape * 100 > threshold_box.value
            self._beam_shape_artist = target_axes.imshow(mask)
            self._beam_shape_fig.canvas.draw_idle()

        threshold_box.observe(redraw, "value")
        # Draw once immediately so the initial threshold is visible.
        redraw()

        controls_and_plot = [threshold_box, self._beam_shape_fig.canvas]
        return VBox(controls_and_plot)
Exemple #3
0
    def build_options(self):
        """Build the plotting option widgets and arrange them in a grid.

        Creates one widget per plotting option. It then wires observers that:
        - keep dependent widgets enabled or disabled;
        - keep the custom range bounds consistent with each other.
        Everything is placed in a 10x2 ``GridspecLayout``.

        Returns:
            tuple: ``(options_map, grid)`` where ``options_map`` maps option
            names (``'feature'``, ``'num_grid_points'``, ``'grid_type'``, ...)
            to their widgets and ``grid`` is the populated ``GridspecLayout``.
        """
        grid = GridspecLayout(10, 2)
        options_map = {}
        style = {'description_width': '60%', 'width': 'auto'}

        # feature
        feature = Combobox(description='Feature to plot:',
                           style=style,
                           options=list(self.feature_names),
                           ensure_option=True,
                           value=self.feature_names[0])
        options_map['feature'] = feature

        # num_grid_points
        num_grid_points = BoundedIntText(
            value=10,
            min=1,
            max=999999,
            step=1,
            description='Number of grid points:',
            style=style,
            description_tooltip='Number of grid points for numeric feature')
        options_map['num_grid_points'] = num_grid_points

        # grid_type
        grid_type = Dropdown(
            description='Grid type:',
            options=['percentile', 'equal'],
            style=style,
            description_tooltip='Type of grid points for numeric feature')
        options_map['grid_type'] = grid_type

        # cust_range
        cust_range = Checkbox(description='Custom grid range', value=False)
        options_map['cust_range'] = cust_range

        # range_min
        range_min = FloatText(
            description='Custom range minimum:',
            style=style,
            description_tooltip=
            'Percentile (when grid_type="percentile") or value (when grid_type="equal") '
            'lower bound of range to investigate (for numeric feature)\n'
            ' - Enabled only when custom grid range is True and variable with grid points is None',
            disabled=True)
        options_map['range_min'] = range_min

        # range_max
        range_max = FloatText(
            description='Custom range maximum:',
            style=style,
            description_tooltip=
            'Percentile (when grid_type="percentile") or value (when grid_type="equal") '
            'upper bound of range to investigate (for numeric feature)\n'
            ' - Enabled only when custom grid range is True and variable with grid points is None',
            disabled=True)
        options_map['range_max'] = range_max

        # cust_grid_points
        cust_grid_points = UpdatingCombobox(
            options_keys=self.globals_options,
            description='Variable with grid points:',
            style=style,
            description_tooltip=
            'Name of variable (or None) with customized list of grid points for numeric feature',
            value='None',
            disabled=True)
        cust_grid_points.lookup_in_kernel = True
        options_map['cust_grid_points'] = cust_grid_points

        # set up disabling of range inputs, when user doesn't want custom range
        def disable_ranges(change):
            range_min.disabled = not change['new']
            range_max.disabled = not change['new']
            cust_grid_points.disabled = not change['new']
            # but if the cust_grid_points has a value filled in keep range_max and range_min disabled
            if cust_grid_points.value != 'None':
                range_max.disabled = True
                range_min.disabled = True

        cust_range.observe(disable_ranges, names=['value'])

        # set up disabling of range_max and range_min if user specifies custom grid points
        def disable_max_min(change):
            # The bounds are editable only while a custom range is requested
            # AND no variable with explicit grid points is selected; checking
            # cust_range here keeps them disabled when the range is switched off.
            bounds_enabled = cust_range.value and change['new'] == 'None'
            range_max.disabled = not bounds_enabled
            range_min.disabled = not bounds_enabled

        cust_grid_points.observe(disable_max_min, names=['value'])

        # set up links between upper and lower ranges
        def set_ranges(change):
            # In percentile mode range_max is kept >= range_min + num_grid_points.
            # In equal mode range_max is kept >= range_min. Whichever widget
            # changed wins and the other bound is pushed out of the way.
            owner = change['owner']
            if grid_type.value == 'percentile':
                if owner is range_min or owner is num_grid_points:
                    range_max.value = max(
                        range_max.value,
                        range_min.value + num_grid_points.value)
                if owner is range_max:
                    range_min.value = min(
                        range_min.value,
                        range_max.value - num_grid_points.value)
            else:
                if owner is range_min:
                    range_max.value = max(range_max.value, range_min.value)
                if owner is range_max:
                    range_min.value = min(range_min.value, range_max.value)

        range_min.observe(set_ranges, names=['value'])
        range_max.observe(set_ranges, names=['value'])
        num_grid_points.observe(set_ranges, names=['value'])

        # center
        center = Checkbox(description='Center the plot', value=True)
        options_map['center'] = center

        # plot_pts_dist
        plot_pts_dist = Checkbox(description='Plot data points distribution',
                                 value=True)
        options_map['plot_pts_dist'] = plot_pts_dist

        # x_quantile
        x_quantile = Checkbox(description='X-axis as quantiles', value=False)
        options_map['x_quantile'] = x_quantile

        # show_percentile
        show_percentile = Checkbox(description='Show percentile buckets',
                                   value=False)
        options_map['show_percentile'] = show_percentile

        # lines
        lines = Checkbox(description='Plot lines - ICE plot', value=False)
        options_map['lines'] = lines

        # frac_to_plot
        frac_to_plot = BoundedFloatText(
            description='Lines to plot:',
            value=1,
            description_tooltip=
            'How many lines to plot, can be an integer or a float.\n'
            ' - integer values higher than 1 are interpreted as absolute amount\n'
            ' - floats are interpreted as fraction (e.g. 0.5 means half of all possible lines)',
            style=style,
            disabled=True)
        options_map['frac_to_plot'] = frac_to_plot

        # cluster
        cluster = Checkbox(description='Cluster lines',
                           value=False,
                           disabled=True)
        options_map['cluster'] = cluster

        # n_cluster_centers
        n_cluster_centers = BoundedIntText(
            value=10,
            min=1,
            max=999999,
            step=1,
            description='Number of cluster centers:',
            style=style,
            description_tooltip='Number of cluster centers for lines',
            disabled=True)
        options_map['n_cluster_centers'] = n_cluster_centers

        # cluster method
        cluster_method = Dropdown(
            description='Cluster method',
            style=style,
            options={
                'KMeans': 'accurate',
                'MiniBatchKMeans': 'approx'
            },
            description_tooltip='Method to use for clustering of lines',
            disabled=True)
        options_map['cluster_method'] = cluster_method

        # set up disabling of lines related options
        def disable_lines(change):
            frac_to_plot.disabled = not change['new']
            cluster.disabled = not change['new']
            n_cluster_centers.disabled = not (change['new'] and cluster.value)
            cluster_method.disabled = not (change['new'] and cluster.value)

        lines.observe(disable_lines, names=['value'])

        # set up disabling of clustering options
        def disable_clustering(change):
            # Clustering options require both line plotting and clustering to
            # be on (mirrors disable_lines); change['new'] is cluster's value.
            clustering_enabled = lines.value and change['new']
            n_cluster_centers.disabled = not clustering_enabled
            cluster_method.disabled = not clustering_enabled

        cluster.observe(disable_clustering, names=['value'])

        grid[0, :] = feature
        grid[1, 0] = num_grid_points
        grid[1, 1] = grid_type
        grid[2, 0] = cust_range
        grid[2, 1] = cust_grid_points
        grid[3, 0] = range_min
        grid[3, 1] = range_max
        grid[4, 0] = center
        grid[4, 1] = plot_pts_dist
        grid[5, 0] = x_quantile
        grid[5, 1] = show_percentile
        grid[6, :] = lines
        grid[7, :] = frac_to_plot
        grid[8, :] = cluster
        grid[9, 0] = n_cluster_centers
        grid[9, 1] = cluster_method

        return options_map, grid
Exemple #4
0
class DatasetAnnotatorClassification:
    """Jupyter widget for labelling a list of observations with classes.

    Observations are shown one at a time in an ``Output`` area. Images are
    rendered with matplotlib by default; other observations are rendered with
    a user-supplied display function. The class selector sits next to them:
    checkboxes for multi-label tasks, radio buttons otherwise. Annotations
    are kept in a dict with ``categories`` and ``observations`` lists and are
    saved as JSON to ``<output_path>/<name>.json``. An existing file at that
    path is loaded and extended.
    """
    supported_types = [
        TaskType.CLASSIFICATION_BINARY, TaskType.CLASSIFICATION_SINGLE_LABEL,
        TaskType.CLASSIFICATION_MULTI_LABEL
    ]

    def __init__(self,
                 task_type,
                 observations,
                 output_path,
                 name,
                 classes,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 custom_display_function=None,
                 is_image=True):
        """Validate the arguments, load/create the annotation file and build the UI.

        :param task_type: one of ``supported_types``.
        :param observations: sequence of observations (image paths when
          ``is_image`` is True); must be non-empty.
        :param output_path: directory of the JSON annotation file.
        :param name: base name of the JSON file; also used as CSS class of
          the output area.
        :param classes: class names; at least two, exactly two for binary tasks.
        :param show_name: print the image file name above each image.
        :param show_axis: show matplotlib axes around images.
        :param fig_size: matplotlib figure size for images.
        :param buttons_vertical: stack navigation buttons vertically.
        :param custom_display_function: callable receiving the observation
          record; required when ``is_image`` is False.
        :param is_image: whether observations are image paths.
        :raises Exception: on unsupported task type, empty observations, too
          few classes, more than two classes for a binary task, or a missing
          display function for non-image data.
        """
        if task_type not in self.supported_types:
            raise Exception(labels_str.warn_task_not_supported)

        if len(observations) == 0:
            raise Exception(labels_str.warn_no_images)

        num_classes = len(classes)
        if num_classes <= 1:
            raise Exception(labels_str.warn_little_classes)
        if (num_classes > 2 and
                task_type.value == TaskType.CLASSIFICATION_BINARY.value):
            raise Exception(labels_str.warn_binary_only_two)

        self.is_image = is_image
        # Observation records are keyed by "path" for images, "observation" otherwise.
        self.key = "path" if self.is_image else "observation"
        if not self.is_image and custom_display_function is None:
            raise Exception(labels_str.warn_display_function_needed)

        self.task_type = task_type
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        self.output_path = output_path
        self.file_path = os.path.join(self.output_path, self.name + ".json")
        print(labels_str.info_ds_output + self.file_path)
        self.mapping, self.dataset = self.__create_results_dict(
            self.file_path, classes)

        # Includes categories already present in a loaded file.
        self.classes = list(self.mapping["categories_id"].values())

        if (len(self.classes) > 2 and
                task_type.value == TaskType.CLASSIFICATION_BINARY.value):
            raise Exception(labels_str.warn_binary_only_two +
                            " ".join(self.classes))

        self.observations = observations
        self.max_pos = len(self.observations) - 1
        self.pos = 0
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if custom_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = custom_display_function

        self.previous_button = self.__create_button(labels_str.str_btn_prev,
                                                    (self.pos == 0),
                                                    self.__on_previous_clicked)
        self.next_button = self.__create_button(labels_str.str_btn_next,
                                                (self.pos == self.max_pos),
                                                self.__on_next_clicked)
        self.save_button = self.__create_button(labels_str.str_btn_download,
                                                False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function

        buttons = [self.previous_button, self.next_button, self.save_button]

        label_total = Label(value='/ {}'.format(len(self.observations)))
        self.text_index = BoundedIntText(value=1,
                                         min=1,
                                         max=len(self.observations))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)

        if self.__is_multilabel():
            self.checkboxes = [
                Checkbox(False,
                         description='{}'.format(class_name),
                         indent=False) for class_name in self.classes
            ]
            for cb in self.checkboxes:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=list(self.checkboxes))
        else:
            self.checkboxes = RadioButtons(options=self.classes,
                                           disabled=False,
                                           indent=False)
            self.checkboxes.layout.width = '180px'
            self.checkboxes.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=[self.checkboxes])

        output_layout = HBox(children=[self.out, self.checkboxes_layout])
        if self.buttons_vertical:
            self.all_widgets = HBox(children=[
                VBox(children=[HBox([self.text_index, label_total])] +
                     buttons), output_layout
            ])
        else:
            self.all_widgets = VBox(children=[
                HBox([self.text_index, label_total]),
                HBox(children=buttons), output_layout
            ])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        """Load (or create) the dataset and build lookup tables.

        :param file_path: JSON annotation file; loaded when it exists.
        :param cc: class names; any not yet present are appended as new
          categories.
        :returns: ``(mapping, dataset)`` where ``mapping`` has
          ``categories_id`` (id -> name), ``categories_name`` (name -> id) and
          ``observations`` (observation key -> index in
          ``dataset['observations']``).
        """
        mapping = {
            "categories_id": {},
            "categories_name": {},
            "observations": {}
        }

        if not os.path.exists(file_path):
            dataset = {'categories': [], "observations": []}
            for index, c in enumerate(cc):
                category = {"supercategory": c, "name": c, "id": index + 1}
                dataset["categories"].append(category)
        else:
            with open(file_path, 'r') as classification_file:
                dataset = json.load(classification_file)
            for index, img in enumerate(dataset['observations']):
                mapping['observations'][img[self.key]] = index

        for c in dataset['categories']:
            mapping['categories_id'][c["id"]] = c["name"]
            mapping['categories_name'][c["name"]] = c["id"]
        # Allocate new ids above the largest existing one so a loaded file
        # with non-contiguous ids never receives a duplicate id.
        index_categories = max(
            (c["id"] for c in dataset['categories']), default=0) + 1

        for c in cc:
            if c not in mapping['categories_name']:
                mapping['categories_id'][index_categories] = c
                mapping['categories_name'][c] = index_categories
                category = {
                    "supercategory": c,
                    "name": c,
                    "id": index_categories
                }
                dataset["categories"].append(category)
                index_categories += 1

        return mapping, dataset

    def __checkbox_changed(self, b):
        """Record a class selector change on the current observation."""
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value
        current_index = self.mapping["observations"][self.observations[
            self.pos]]

        if self.__is_multilabel():
            class_index = self.mapping["categories_name"][class_name]
            categories = self.dataset["observations"][current_index][
                "categories"]
            if value and class_index not in categories:
                categories.append(class_index)
            if not value and class_index in categories:
                categories.remove(class_index)
        else:
            # For radio buttons the value is the selected class name.
            class_index = self.mapping["categories_name"][value]
            self.dataset["observations"][current_index][
                "category"] = class_index

        # The last observation has no "next" click to trigger a save.
        if self.pos == self.max_pos:
            self.save_state()

    def __is_multilabel(self):
        return TaskType.CLASSIFICATION_MULTI_LABEL.value == self.task_type.value

    def __create_button(self, description, disabled, function):
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_record):
        """Default display function: plot the image at ``image_record['path']``.

        Prints a message and returns without plotting when the record has no
        path or the file does not exist.
        """
        if 'path' not in image_record:
            print("missing path")
            return
        if not os.path.exists(image_record['path']):
            print("Image cannot be loaded: " + image_record['path'])
            return

        # Context manager releases the file handle once matplotlib has the data.
        with Image.open(image_record['path']) as img:
            if self.show_name:
                print(os.path.basename(image_record['path']))
            plt.figure(figsize=self.fig_size)
            if not self.show_axis:
                plt.axis('off')
            plt.imshow(img)
        plt.show()

    def save_state(self):
        """Write the current dataset to ``self.file_path`` as indented JSON."""
        with open(self.file_path, 'w') as output_file:
            json.dump(self.dataset, output_file, indent=4)

    def __save_function(self, image_path, index):
        """Trigger a browser download of a screenshot of the output area.

        Uses the html2canvas library configured in ``__init__``; the PNG is
        named after ``image_path`` without its extension.
        """
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        # Escape quotes/backslashes so the name is safe inside the JS string.
        js_img_name = json.dumps(img_name)[1:-1]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        j_code = j_code.replace('$img_name$', js_img_name)
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        """Show the observation at ``self.pos`` and sync all widgets to it.

        Creates a dataset record for the observation on first visit.
        """
        if self.observations[self.pos] not in self.mapping["observations"]:
            observation = {self.key: self.observations[self.pos]}
            if self.is_image:
                observation["file_name"] = os.path.basename(
                    self.observations[self.pos])

            observation["id"] = len(self.mapping["observations"]) + 1
            if self.__is_multilabel():
                observation["categories"] = []
            self.dataset["observations"].append(observation)

            self.mapping["observations"][observation[self.key]] = len(
                self.dataset["observations"]) - 1

        current_index = self.mapping["observations"][self.observations[
            self.pos]]
        self.next_button.disabled = (self.pos == self.max_pos)
        self.previous_button.disabled = (self.pos == 0)

        if self.__is_multilabel():
            # Records loaded from file may lack a "categories" list.
            categories = self.dataset["observations"][
                current_index].setdefault("categories", [])
            for cb in self.checkboxes:
                # Detach the observer so restoring state isn't recorded as an edit.
                cb.unobserve(self.__checkbox_changed)
                cb.value = self.mapping["categories_name"][
                    cb.description] in categories
                cb.observe(self.__checkbox_changed)
        else:
            self.checkboxes.unobserve(self.__checkbox_changed)
            obs = self.dataset["observations"][current_index]
            category = obs.get("category")

            if category:
                # Translate the stored category id back to its class name.
                category = next(
                    (k for k, cid in self.mapping["categories_name"].items()
                     if cid == category), category)
            self.checkboxes.value = category
            self.checkboxes.observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            self.image_display_function(
                self.dataset["observations"][current_index])

        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.pos + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        self.save_state()
        self.pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        self.save_state()
        self.pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        self.save_state()
        self.save_function(self.observations[self.pos], self.pos)

    def __selected_index(self, t):
        """Jump to the 1-based index typed in the index box."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.pos = t['new'] - 1
        self.__perform_action()

    def start_classification(self):
        """Display the widgets and show the first observation."""
        if self.max_pos < self.pos:
            print("No available observation")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        """Print a table of annotation counts per class, sorted by class name."""
        counter = dict.fromkeys(self.mapping["categories_id"], 0)

        for record in self.dataset["observations"]:

            if self.__is_multilabel():
                if "categories" in record:
                    for c in record["categories"]:
                        counter[c] += 1
            elif "category" in record:
                counter[record["category"]] += 1

        table = sorted(
            ([self.mapping["categories_id"][c], count]
             for c, count in counter.items()),
            key=lambda x: x[0])

        print(
            tabulate(table,
                     headers=[
                         labels_str.info_class_name, labels_str.info_ann_images
                     ]))
class TwittipediaView:
    def __init__(self, controller: TwittipediaController):
        """Build the Twittipedia widget tree and wire its event handlers.

        The UI consists of:
        - a Twitter section (query field, tweet count, search/load/reset
          buttons);
        - a Wikipedia search section (k_1 and b parameter fields, search/reset
          buttons);
        - a results area.
        The assembled widget is stored in ``self.widget``. The change and
        click handlers (``_query_field_changed``, ...) are defined elsewhere
        in this class.

        :param controller: stored as ``self.controller``.
        """

        self.controller = controller

        title = HTML(value="<h1>Twittipedia</h1>",
                     description="",
                     disabled=False)

        # Twitter

        twitter_title = HTML(value="<b>Tweets</b>",
                             description="",
                             disabled=False)

        self.query_field = Text(value="",
                                placeholder="Search or enter username",
                                description="",
                                disabled=False,
                                layout=dict(width="auto"))
        self.query_field.observe(self._query_field_changed, names="value")

        self.number_of_tweets_field = BoundedIntText(
            value=DEFAULT_NUMBER_OF_TWEETS,
            min=1,
            max=100,
            step=1,
            description="",
            disabled=False,
            layout=dict(width="auto"))
        self.number_of_tweets_field.observe(
            self._number_of_tweets_field_changed, names="value")
        number_of_tweets_label = Label(value="most recent tweets",
                                       disabled=False,
                                       layout=dict(width="auto"))
        number_of_tweets_field_with_label = HBox(
            (self.number_of_tweets_field, number_of_tweets_label))

        self.search_tweets_button = Button(description="Search Tweets",
                                           disabled=True,
                                           button_style="primary",
                                           tooltip="",
                                           icon="",
                                           layout=dict(width="auto"))
        self.load_tweets_from_user_button = Button(
            description="Load Tweets from User",
            disabled=True,
            button_style="",
            tooltip="",
            icon="",
            layout=dict(width="auto"))
        self.twitter_search_buttons = [
            self.load_tweets_from_user_button, self.search_tweets_button
        ]

        # Both search buttons share one handler.
        for button in self.twitter_search_buttons:
            button.on_click(self._twitter_search_button_pressed)

        twitter_search_buttons_box = Box(self.twitter_search_buttons,
                                         layout=dict(
                                             justify_content="flex-end",
                                             flex_flow="row wrap",
                                         ))

        self.reset_and_clear_tweets_button = Button(
            description="Reset and Clear Tweets",
            disabled=True,
            button_style="danger",
            tooltip="",
            icon="",
            layout=dict(width="auto"))
        self.reset_and_clear_tweets_button.on_click(
            self._reset_and_clear_tweets_button_pressed)

        self.twitter_buttons = self.twitter_search_buttons \
            + [self.reset_and_clear_tweets_button]
        twitter_buttons_box = Box(
            (self.reset_and_clear_tweets_button, twitter_search_buttons_box),
            layout=dict(
                justify_content="space-between",
                flex_flow="row wrap",
            ))

        twitter_box = VBox(
            (twitter_title, self.query_field,
             number_of_tweets_field_with_label, twitter_buttons_box))

        # Wikipedia search

        wikipedia_title = HTML(value="<b>Wikipedia Search</b>",
                               description="",
                               disabled=False)

        # Each entry renders as one row: label | input field | (explanation).
        self.wikipedia_options = []

        self.term_frequency_scaling_parameter_field = BoundedFloatText(
            value=DEFAULT_TERM_FREQUENCY_SCALING_PARAMETER_VALUE,
            min=0,
            max=100,
            step=0.1,
            description="",
            disabled=False,
            layout=dict(width="auto"))
        self.term_frequency_scaling_parameter_field.observe(
            self._term_frequency_scaling_parameter_field_changed,
            names="value")
        # Raw strings: "\g"/"\l" are invalid escapes in normal literals
        # (warned about by Python); the raw form yields the same text.
        self.wikipedia_options.append({
            "label": "$k_1$${{}}=$",
            "field": self.term_frequency_scaling_parameter_field,
            "explanation": r"$k_1 \ge 0$"
        })

        self.document_length_scaling_parameter_field = BoundedFloatText(
            value=DEFAULT_DOCUMENT_LENGTH_SCALING_PARAMETER_VALUE,
            min=0,
            max=1,
            step=0.01,
            description="",
            disabled=False,
            layout=dict(width="auto"))
        self.document_length_scaling_parameter_field.observe(
            self._document_length_scaling_parameter_field_changed,
            names="value")
        self.wikipedia_options.append({
            "label": "$b$${{}}=$",
            "field": self.document_length_scaling_parameter_field,
            "explanation": r"$0 \le b$${} \le 1$"
        })

        wikipedia_options_box = Box(
            (VBox([
                Label(value=option["label"])
                for option in self.wikipedia_options
            ],
                  layout=dict(align_items="flex-end")),
             VBox([option["field"] for option in self.wikipedia_options], ),
             VBox([
                 Label(value="(" + option["explanation"] + ")")
                 for option in self.wikipedia_options
             ])))

        self.search_wikipedia_button = Button(
            description="Search Wikipedia with Tweets",
            disabled=True,
            button_style="primary",
            tooltip="",
            icon="",
            layout=dict(width="auto"))
        self.search_wikipedia_button.on_click(
            self._search_wikipedia_button_pressed)

        self.reset_and_clear_wikipedia_results_button = Button(
            description="Reset and Clear Results",
            disabled=True,
            button_style="danger",
            tooltip="",
            icon="",
            layout=dict(width="auto"))
        self.reset_and_clear_wikipedia_results_button.on_click(
            self._reset_and_clear_wikipedia_results_button_pressed)

        self.wikipedia_buttons = [
            self.reset_and_clear_wikipedia_results_button,
            self.search_wikipedia_button
        ]

        wikipedia_buttons_box = Box(self.wikipedia_buttons,
                                    layout=dict(
                                        justify_content="space-between",
                                        flex_flow="row wrap",
                                    ))

        wikipedia_box = VBox(
            (wikipedia_title, wikipedia_options_box, wikipedia_buttons_box))

        # Result views

        results_title = HTML(value="<b>Results</b>",
                             description="",
                             disabled=True)

        self.results_box = VBox([DEFAULT_RESULTS_WIDGET])
        self.results = None

        # Together

        self.buttons = self.twitter_buttons + self.wikipedia_buttons

        search_box = VBox((twitter_box, wikipedia_box),
                          # layout=dict(max_width="600px")
                          )

        results_box = VBox((results_title, self.results_box))

        self.widget = VBox((title, search_box, results_box))

        # Put every field and button into its initial state.
        self._reset_and_clear_tweets()

    def _disable_twitter_buttons(self):
        """Grey out every widget listed in ``self.twitter_buttons``."""
        for widget in self.twitter_buttons:
            widget.disabled = True

    def _enable_twitter_buttons(self):
        """Make every widget listed in ``self.twitter_buttons`` clickable."""
        for widget in self.twitter_buttons:
            widget.disabled = False

    def _disable_twitter_search_buttons(self):
        """Grey out every widget listed in ``self.twitter_search_buttons``."""
        for widget in self.twitter_search_buttons:
            widget.disabled = True

    def _enable_twitter_search_buttons(self):
        """Make every widget listed in ``self.twitter_search_buttons`` clickable."""
        for widget in self.twitter_search_buttons:
            widget.disabled = False

    def _disable_wikipedia_buttons(self):
        """Grey out every widget listed in ``self.wikipedia_buttons``."""
        for widget in self.wikipedia_buttons:
            widget.disabled = True

    def _enable_wikipedia_buttons(self):
        """Make every widget listed in ``self.wikipedia_buttons`` clickable."""
        for widget in self.wikipedia_buttons:
            widget.disabled = False

    def _disable_all_buttons(self):
        """Grey out every widget in ``self.buttons`` (Twitter and Wikipedia)."""
        for widget in self.buttons:
            widget.disabled = True

    def _enable_all_buttons(self):
        """Make every widget in ``self.buttons`` (Twitter and Wikipedia) clickable."""
        for widget in self.buttons:
            widget.disabled = False

    def _reset_and_clear_tweets(self):
        """Return the Twitter part of the UI to its initial state.

        Clears the query, restores the default tweet count, shows the
        placeholder results widget, disables the Twitter buttons and the
        Wikipedia search button, and finally resets the Wikipedia parameters
        and results as well.
        """
        # Statement order matters: assigning the field values fires their
        # change observers, which adjust button states.
        self.query_field.value = ""
        self.number_of_tweets_field.value = DEFAULT_NUMBER_OF_TWEETS
        self.results_box.children = [DEFAULT_RESULTS_WIDGET]
        self._disable_twitter_buttons()
        self.search_wikipedia_button.disabled = True
        self._reset_and_clear_wikipedia_results()

    def _reset_and_clear_wikipedia_results(self):
        """Restore the default k_1 / b scaling parameters and blank the article panes.

        The Wikipedia reset button is disabled afterwards since there is
        nothing left to reset.
        """
        k1_field = self.term_frequency_scaling_parameter_field
        b_field = self.document_length_scaling_parameter_field
        k1_field.value = DEFAULT_TERM_FREQUENCY_SCALING_PARAMETER_VALUE
        b_field.value = DEFAULT_DOCUMENT_LENGTH_SCALING_PARAMETER_VALUE
        self._clear_wikipedia_results()
        self.reset_and_clear_wikipedia_results_button.disabled = True

    def _clear_wikipedia_results(self):
        """Blank the articles pane of every tweet row currently held in ``self.results``.

        Does nothing when ``self.results`` is not a list (no tweets loaded).
        """
        if not isinstance(self.results, list):
            return
        for row in self.results:
            if "articles" in row:
                row["articles"].value = ""

    def _query_field_changed(self, notification):
        """Update the Twitter button states after the query text changes.

        The search button is usable only with a non-empty query. The reset
        button is usable while there is something to reset: a non-empty query
        or a tweet count that differs from ``DEFAULT_NUMBER_OF_TWEETS``. The
        load-from-user button is usable only when ``check_twitter_username``
        accepts the query.

        Args:
            notification: traitlets change notification for ``query_field``.
        """
        if notification.type != "change":
            return

        query = notification.new

        self.search_tweets_button.disabled = not query
        # Consider the tweet-count field as well, so that clearing the query
        # does not disable reset while a non-default count is still set.
        self.reset_and_clear_tweets_button.disabled = (
            not query
            and self.number_of_tweets_field.value == DEFAULT_NUMBER_OF_TWEETS)

        is_username, _ = check_twitter_username(query)
        self.load_tweets_from_user_button.disabled = not is_username

    def _number_of_tweets_field_changed(self, notification):
        """Update the tweet reset button after the tweet count changes.

        The reset button is disabled only when nothing differs from the
        initial state: the count equals ``DEFAULT_NUMBER_OF_TWEETS`` and the
        query field is empty.

        Args:
            notification: traitlets change notification for
                ``number_of_tweets_field``.
        """
        if notification.type != "change":
            return

        # Consider the query too, so restoring the default count does not
        # disable reset while a query is still typed in.
        self.reset_and_clear_tweets_button.disabled = (
            notification.new == DEFAULT_NUMBER_OF_TWEETS
            and not self.query_field.value)

    def _term_frequency_scaling_parameter_field_changed(self, notification):
        """Update the Wikipedia reset button after the k_1 value changes.

        The button is disabled only while both scaling parameters (k_1 and b)
        are at their defaults, so moving k_1 back to its default keeps the
        button usable while b still differs from its default.

        Args:
            notification: traitlets change notification for
                ``term_frequency_scaling_parameter_field``.
        """
        if notification.type != "change":
            return

        at_defaults = (
            notification.new == DEFAULT_TERM_FREQUENCY_SCALING_PARAMETER_VALUE
            and self.document_length_scaling_parameter_field.value
            == DEFAULT_DOCUMENT_LENGTH_SCALING_PARAMETER_VALUE)
        self.reset_and_clear_wikipedia_results_button.disabled = at_defaults

    def _document_length_scaling_parameter_field_changed(self, notification):
        """Update the Wikipedia reset button after the b value changes.

        The button is disabled only while both scaling parameters (k_1 and b)
        are at their defaults, so moving b back to its default keeps the
        button usable while k_1 still differs from its default.

        Args:
            notification: traitlets change notification for
                ``document_length_scaling_parameter_field``.
        """
        if notification.type != "change":
            return

        at_defaults = (
            notification.new == DEFAULT_DOCUMENT_LENGTH_SCALING_PARAMETER_VALUE
            and self.term_frequency_scaling_parameter_field.value
            == DEFAULT_TERM_FREQUENCY_SCALING_PARAMETER_VALUE)
        self.reset_and_clear_wikipedia_results_button.disabled = at_defaults

    def _twitter_search_button_pressed(self, button):
        """Search tweets or load a user's timeline, depending on ``button``.

        Shows a progress message in the results area, disables every button
        while ``self.controller.search_tweets`` runs, and re-enables the
        Twitter buttons afterwards -- also when the controller raises, so the
        UI is never left locked.

        Args:
            button: the clicked widget; must be ``self.search_tweets_button``
                or ``self.load_tweets_from_user_button``.

        Raises:
            ValueError: if ``button`` is neither of those two buttons.
        """
        import html

        query = self.query_field.value
        count = self.number_of_tweets_field.value

        if button is self.search_tweets_button:
            twitter_method = "search"
            query_string = f"Searching for recent tweets matching \"{query}\"..."
        elif button is self.load_tweets_from_user_button:
            twitter_method = "user_timeline"
            username = query.lstrip("@")
            query_string = f"Loading tweets from @{username}..."
        else:
            raise ValueError(f"Unexpected button: {button!r}")

        # The query is user-typed text: escape it so characters such as "<"
        # or "&" are displayed literally instead of being parsed as markup.
        self.results_box.children = [
            HTML(f"<i>{html.escape(query_string)}</i>")]

        self._disable_all_buttons()
        try:
            self.controller.search_tweets(query=query,
                                          twitter_method=twitter_method,
                                          count=count)
        finally:
            self._enable_twitter_buttons()

    def _reset_and_clear_tweets_button_pressed(self, button):
        """Handle a click on the tweet "Reset and Clear" button.

        ``button`` is the clicked widget passed by ``on_click``; it is unused.
        """
        reset = self._reset_and_clear_tweets
        reset()

    def _reset_and_clear_wikipedia_results_button_pressed(self, button):
        """Handle a click on the Wikipedia "Reset and Clear Results" button.

        ``button`` is the clicked widget passed by ``on_click``; it is unused.
        """
        reset = self._reset_and_clear_wikipedia_results
        reset()

    def _search_wikipedia_button_pressed(self, button):
        """Search Wikipedia for the loaded tweets using the chosen parameters.

        Reads ``k_1`` (term-frequency scaling) and ``b`` (document-length
        scaling) from their fields and passes them to
        ``self.controller.search_wikipedia``. All buttons are disabled while
        the search runs and re-enabled afterwards -- also when the controller
        raises, so the UI is never left locked.

        ``button`` is the clicked widget passed by ``on_click``; it is unused.
        """
        k_1 = self.term_frequency_scaling_parameter_field.value
        b = self.document_length_scaling_parameter_field.value

        self._disable_all_buttons()
        try:
            self.controller.search_wikipedia(k_1=k_1, b=b)
        finally:
            self._enable_all_buttons()

    def show_tweets(self, tweets):
        """Render ``tweets`` as rows of (tweet, articles) widgets.

        For every tweet an ``Output`` widget (half width) receives the tweet's
        HTML and an empty ``HTML`` widget (half width) is reserved for the
        Wikipedia articles shown later. The rows replace the results box
        contents, a script tag loading Twitter's ``widgets.js`` is displayed,
        and the Wikipedia search button is enabled.

        Args:
            tweets: iterable of objects providing ``as_html(hide_thread=...)``.
        """
        self.results = []

        for tweet in tweets:
            tweet_output = Output(layout=dict(width="50%"))
            articles_html = HTML(layout=dict(width="50%"))
            self.results.append({"tweet": tweet_output,
                                 "articles": articles_html})
            with tweet_output:
                display(IPHTML(tweet.as_html(hide_thread=True)))

        rows = []
        for row in self.results:
            rows.append(HBox((row["tweet"], row["articles"])))
        self.results_box.children = rows

        widgets_script = '<script id="twitter-wjs" type="text/javascript" async defer src="//platform.twitter.com/widgets.js"></script>'
        display(IPHTML(widgets_script))
        self.search_wikipedia_button.disabled = False

    def show_articles_for_tweet_number(self, i, formatted_results):
        """Show ``formatted_results`` in the articles pane of tweet row ``i``.

        ``formatted_results`` becomes the ``value`` of that row's ``HTML``
        widget. The call is ignored when no tweets are loaded, when ``i`` is
        outside ``0 <= i < len(self.results)``, or when the row has no
        ``HTML`` articles widget.

        Args:
            i: zero-based tweet row index.
            formatted_results: HTML string to display.
        """
        if not isinstance(self.results, list):
            return
        # Reject negative indices explicitly: Python would otherwise wrap
        # them around and write into a row counted from the end.
        if not 0 <= i < len(self.results):
            return
        row = self.results[i]
        if "articles" in row and isinstance(row["articles"], HTML):
            row["articles"].value = formatted_results
Exemple #6
0
class SubstrateTab(object):
    """Notebook tab that contour-plots one substrate field per output frame.

    ``self.tab`` is the widget to embed. The left column holds the field and
    colormap selectors, an optional fixed colour range (with a button that
    remembers it per substrate) and the slider limit. The right column is an
    ``interactive`` plot driven by :meth:`plot_substrate`.
    """

    def __init__(self):
        self.output_dir = '.'

        # Row of the microenvironment matrix that gets plotted. It is always
        # the dropdown value + 4 (see mcds_field_cb); 4 matches value 0.
        self.field_index = 4

        tab_height = '500px'
        constWidth = '180px'
        constWidth2 = '150px'
        tab_layout = Layout(width='900px',   # border='2px solid black',
                            height=tab_height, )  # overflow_y='scroll')

        max_frames = 253   # first time + 30240 / 120
        self.mcds_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)
        svg_plot_size = '700px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max frames',
            layout=Layout(width='160px'),
        )
        # names='value' everywhere below: react once per user edit instead of
        # on every trait notification the widget emits (which re-plotted
        # several times per change).
        self.max_frames.observe(self.update_max_frames, names='value')

        # Colour range [min, max] per substrate; loaded when the selected
        # field changes and overwritten by the "Save" button.
        self.field_min_max = {'oxygen': [0., 38.], 'glucose': [0.8, 1.], 'H+ ions': [0., 1.],
                              'ECM': [0., 1.], 'NP1': [0., 1.], 'NP2': [0., 0.1]}
        # Inverse of the Dropdown options below: dropdown value -> field name.
        self.field_dict = {0: 'oxygen', 1: 'glucose', 2: 'H+ ions', 3: 'ECM', 4: 'NP1', 5: 'NP2'}
        self.mcds_field = Dropdown(
            options={'oxygen': 0, 'glucose': 1, 'H+ ions': 2, 'ECM': 3, 'NP1': 4, 'NP2': 5},
            value=0,
            layout=Layout(width=constWidth)
        )
        self.mcds_field.observe(self.mcds_field_changed_cb, names='value')

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            layout=Layout(width=constWidth)
        )
        self.field_cmap.observe(self.mcds_field_cb, names='value')

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
            layout=Layout(width=constWidth2),
        )

        self.save_min_max = Button(
            description='Save',
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px')
        )

        def save_min_max_cb(b):
            """Store the current Min/Max as the colour range of the selected field."""
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
        )
        self.cmap_min.observe(self.mcds_field_cb, names='value')

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
        )
        self.cmap_max.observe(self.mcds_field_cb, names='value')

        def cmap_fixed_cb(b):
            """Make Min/Max/Save editable only while 'Fix' is checked."""
            locked = not self.cmap_fixed.value
            self.cmap_min.disabled = locked
            self.cmap_max.disabled = locked
            self.save_min_max.disabled = locked

        self.cmap_fixed.observe(cmap_fixed_cb, names='value')

        field_cmap_row2 = HBox([self.field_cmap, self.cmap_fixed])

        items_auto = [
            self.save_min_max,
            self.cmap_min,
            self.cmap_max,
        ]
        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='80%')
        field_cmap_row3 = Box(children=items_auto, layout=box_layout)

        mcds_params = VBox([self.mcds_field, field_cmap_row2, field_cmap_row3, self.max_frames])

        self.tab = HBox([mcds_params, self.mcds_plot], layout=tab_layout)

    def update_max_frames(self, _b):
        """Use the 'Max frames' value as the upper bound of the frame slider."""
        self.mcds_plot.children[0].max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        """Switch to the newly selected field, load its colour range and re-plot."""
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        """Re-plot after a colormap or colour-range change."""
        self.field_index = self.mcds_field.value + 4
        self.mcds_plot.update()

    @staticmethod
    def _format_time(total_mins):
        """Format whole minutes as ``'<days>d, <hours>h, <minutes>m'``.

        >>> SubstrateTab._format_time(1565)
        '1d, 2h, 5m'
        """
        days, rem = divmod(total_mins, 24 * 60)
        hrs, mins = divmod(rem, 60)
        return '%dd, %dh, %dm' % (days, hrs, mins)

    def plot_substrate(self, frame):
        """Contour-plot the selected substrate field for output frame ``frame``.

        Reads ``output%08d_microenvironment0.mat`` (field values) and
        ``output%08d.xml`` (current time, used as the title) from
        ``self.output_dir``. Prints a message and returns without plotting
        if the .mat file does not exist.
        """
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)

        if not os.path.isfile(full_fname):
            print("No: ", full_fname)
            return

        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        # Truncated to whole minutes.  TODO: check units = mins
        mins = int(float(xml_root.find(".//current_time").text))
        # The old "mins - hrs*60" (with hrs = mins/60.) was always 0, so the
        # minutes part of the title never changed; divmod fixes that.
        title_str = self._format_time(mins)

        info_dict = {}
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']

        plt.figure(figsize=(7.2, 6))  # this strange figsize results in a ~square contour plot

        # Row 0 of M, reshaped to an N x N grid, supplies the axis vector used
        # for both x and y (presumably grid x-coordinates -- TODO confirm).
        N = int(math.sqrt(len(M[0, :])))
        grid2D = M[0, :].reshape(N, N)
        xvec = grid2D[0, :]
        field = M[self.field_index, :].reshape(N, N)

        num_contours = 15
        if self.cmap_fixed.value:
            levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
            my_plot = plt.contourf(xvec, xvec, field, levels=levels, extend='both', cmap=self.field_cmap.value)
        else:
            my_plot = plt.contourf(xvec, xvec, field, num_contours, cmap=self.field_cmap.value)

        plt.title(title_str)
        plt.colorbar(my_plot)
Exemple #7
0
class SVGTab(object):
    """Notebook tab that scatter-plots cells read from ``snapshot%08d.svg`` files.

    ``self.tab`` is the widget to embed. A frame slider (an ``interactive``
    around :meth:`plot_svg`) selects the snapshot; checkboxes toggle nucleus
    circles and black cell edges. When ``hublib_flag`` is set, a download
    button zips all SVG snapshots.
    """

    def __init__(self):
        tab_layout = Layout(width='900px',   # border='2px solid black',
                            height='600px', overflow_y='scroll')

        self.output_dir = '.'

        constWidth = '180px'

        max_frames = 1
        self.svg_plot = interactive(self.plot_svg, frame=(0, max_frames), continuous_update=False)
        plot_size = '600px'
        self.svg_plot.layout.width = plot_size
        self.svg_plot.layout.height = plot_size
        self.use_defaults = True
        self.show_nucleus = 0  # 0->False, 1->True in Checkbox!
        self.show_edge = 1  # 0->False, 1->True in Checkbox!
        self.scale_radius = 1.0
        self.axes_min = 0.0
        self.axes_max = 2000   # hmm, this can change (TODO?)

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        # names='value' everywhere below: react once per user edit instead of
        # on every trait notification (which re-plotted several times).
        self.max_frames.observe(self.update_max_frames, names='value')

        self.show_nucleus_checkbox = Checkbox(
            description='nucleus', value=False, disabled=False,
            layout=Layout(width=constWidth),
        )
        self.show_nucleus_checkbox.observe(self.show_nucleus_cb, names='value')

        self.show_edge_checkbox = Checkbox(
            description='edge', value=True, disabled=False,
            layout=Layout(width=constWidth),
        )
        self.show_edge_checkbox.observe(self.show_edge_cb, names='value')

        items_auto = [Label('select slider: drag or left/right arrows'),
                      self.max_frames,
                      self.show_nucleus_checkbox,
                      self.show_edge_checkbox,
                      ]
        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='70%')
        row1 = Box(children=items_auto, layout=box_layout)

        if hublib_flag:
            self.download_button = Download('svg.zip', style='warning', icon='cloud-download',
                                            tooltip='You need to allow pop-ups in your browser', cb=self.download_cb)
            self.tab = VBox([row1, self.svg_plot, self.download_button.w], layout=tab_layout)
        else:
            self.tab = VBox([row1, self.svg_plot])

    def update(self, rdir=''):
        """Optionally switch to ``rdir`` and set 'Max' to the last snapshot number.

        Setting ``self.max_frames.value`` fires its observer, which makes it
        the slider's upper bound. Assumes the ``snapshot%08d.svg`` naming.
        """
        if rdir:
            self.output_dir = rdir

        all_files = sorted(glob.glob(os.path.join(self.output_dir, 'snapshot*.svg')))
        if all_files:
            last_file = all_files[-1]
            self.max_frames.value = int(last_file[-12:-4])  # assumes naming scheme: "snapshot%08d.svg"

    def download_cb(self):
        """Write every ``*.svg`` in ``self.output_dir`` into ``svg.zip`` in the cwd."""
        file_str = os.path.join(self.output_dir, '*.svg')
        with zipfile.ZipFile('svg.zip', 'w') as myzip:
            for f in glob.glob(file_str):
                myzip.write(f, os.path.basename(f))   # 2nd arg avoids full filename path in the archive

    def show_nucleus_cb(self, b):
        """Mirror the 'nucleus' checkbox into ``self.show_nucleus`` and re-plot."""
        self.show_nucleus = 1 if self.show_nucleus_checkbox.value else 0
        self.svg_plot.update()

    def show_edge_cb(self, b):
        """Mirror the 'edge' checkbox into ``self.show_edge`` and re-plot."""
        self.show_edge = 1 if self.show_edge_checkbox.value else 0
        self.svg_plot.update()

    def update_max_frames(self, _b):
        """Use the 'Max' value as the upper bound of the frame slider."""
        self.svg_plot.children[0].max = self.max_frames.value

    def plot_svg(self, frame):
        """Scatter-plot the cells stored in ``snapshot%08d.svg`` for ``frame``.

        Also records ``frame`` in the module-level ``current_frame``. Prints a
        message and returns without plotting when the snapshot, its tissue
        element, or its ``cells`` group is missing.
        """
        global current_frame
        current_frame = frame
        fname = "snapshot%08d.svg" % frame
        full_fname = os.path.join(self.output_dir, fname)
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")
            return

        xlist = deque()
        ylist = deque()
        rlist = deque()
        rgb_list = deque()

        tree = ET.parse(full_fname)
        root = tree.getroot()

        # Defaults so a snapshot lacking these elements is reported instead of
        # raising UnboundLocalError further down.
        title_str = ""
        tissue_parent = None
        for child in root:
            if self.use_defaults and ('width' in child.attrib.keys()):
                self.axes_max = float(child.attrib['width'])
            if child.text and "Current time" in child.text:
                svals = child.text.split()
                title_str = svals[2] + "d, " + svals[4] + "h, " + svals[7] + "m"
            # The first top-level child carrying an 'id' is taken as the tissue.
            if 'id' in child.attrib.keys():
                tissue_parent = child
                break

        if tissue_parent is None:
            print("No tissue element found in ", full_fname)
            return

        cells_parent = None
        for child in tissue_parent:
            if child.attrib['id'] == 'cells':
                cells_parent = child
                break

        if cells_parent is None:
            print("No cells group found in ", full_fname)
            return

        # Coordinates beyond this magnitude are treated as bogus (rwh TODO: use max of domain?)
        too_large_val = 10000.
        num_cells = 0
        for child in cells_parent:
            for circle in child:  # two circles in each child: outer + nucleus
                xval = float(circle.attrib['cx'])

                s = circle.attrib['fill']
                if s[0:3] == "rgb":  # if an rgb string, e.g. "rgb(175,175,80)"
                    rgb = [int(v) / 255. for v in s[4:-1].split(",")]
                else:     # otherwise, must be a color name
                    rgb = list(mplc.to_rgb(mplc.cnames[s]))

                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    break
                yval = float(circle.attrib['cy'])
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    break

                rval = float(circle.attrib['r'])
                xlist.append(xval)
                ylist.append(yval)
                rlist.append(rval)
                rgb_list.append(rgb)

                # Only the first (outer) circle is kept unless nuclei are shown.
                if self.show_nucleus == 0:
                    break

            num_cells += 1

        xvals = np.array(xlist)
        yvals = np.array(ylist)
        rvals = np.array(rlist)
        rgbs = np.array(rgb_list)

        title_str += " (" + str(num_cells) + " agents)"
        self.fig = plt.figure(figsize=(7, 7))

        # Convert radii from data units to pixels. The axis limits are only
        # set after this conversion, hence the empirical rescaling below.
        ax2 = self.fig.gca()
        N = len(xvals)
        rr_pix = (ax2.transData.transform(np.vstack([rvals, rvals]).T) -
                  ax2.transData.transform(np.vstack([np.zeros(N), np.zeros(N)]).T))
        rpix, _ = rr_pix.T

        markers_size = (144. * rpix / self.fig.dpi)**2   # = (2*rpix / fig.dpi * 72)**2
        markers_size = markers_size / 4000000.

        if self.show_edge:
            try:
                plt.scatter(xvals, yvals, s=markers_size, c=rgbs, edgecolor='black', linewidth=0.5)
            except ValueError:
                # rwh - temp fix: this ValueError has only been seen with
                # edges enabled; skip drawing the cells for this frame.
                pass
        else:
            plt.scatter(xvals, yvals, s=markers_size, c=rgbs)

        plt.xlim(self.axes_min, self.axes_max)
        plt.ylim(self.axes_min, self.axes_max)
        plt.title(title_str)
Exemple #8
0
class MultiClassAnnotator:
    """Jupyter widget for reviewing per-object metric annotations of a VOC-style dataset.

    Shows one ground-truth object at a time (image plus bounding box) with
    Previous/Next/Save buttons, an index box, and one radio-button group or
    checkbox group per requested metric.  The metric widgets are pre-filled
    from ``dataset_voc.annotations_gt``.
    """

    def __init__(self,
                 dataset_voc,
                 output_path_statistic,
                 name,
                 metrics,  # metric names, e.g. ['occ', 'truncated', 'views', 'parts']
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):
        """Build all widgets; call ``start_classification()`` to display them.

        Args:
            dataset_voc: dataset wrapper; ``load()`` is called when its
                ``annotations_gt`` is still None.
            output_path_statistic: directory holding the ``<name>.json`` results
                file and the ``JPEGImages`` folder; None means
                ``dataset_voc.dataset_root_param``.
            name: base name of the results file, also added as a CSS class on
                the image output area.
            metrics: metric names.  'occ' and 'truncated' become radio buttons;
                'parts' is rebuilt per object; any other name becomes one
                checkbox per label ('views' is pre-filled from
                ``side_visible``).
            show_name: print the image file name and class above the image.
            show_axis: keep matplotlib axes visible.
            fig_size: matplotlib figure size for the default image display.
            buttons_vertical: stored only; not used by this class.
            image_display_function: ``f(image_record, obj_num)`` used instead
                of the built-in bbox renderer.
            classes_to_annotate: class names to annotate; None means every
                class in ``dataset_voc.objnames_all``.
        """
        if dataset_voc.annotations_gt is None:  # dataset has not been loaded yet
            dataset_voc.load()

        self.dataset_voc = dataset_voc
        self.metrics = metrics
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        if output_path_statistic is None:
            output_path_statistic = self.dataset_voc.dataset_root_param

        # None -> annotate every class; otherwise only the listed classes.
        # (An explicit list used to be ignored, leaving the attribute unset.)
        if classes_to_annotate is None:
            self.classes_to_annotate = self.dataset_voc.objnames_all
        else:
            self.classes_to_annotate = classes_to_annotate

        self.output_path = output_path_statistic
        self.file_path = os.path.join(self.output_path, self.name + ".json")
        self.mapping, self.dataset = self.__create_results_dict(self.file_path, metrics)

        self.objects = dataset_voc.get_objects_index(self.classes_to_annotate)
        self.current_pos = 0
        self.max_pos = len(self.objects) - 1

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # navigation / save buttons
        self.previous_button = self.__create_button("Previous", self.current_pos == 0, self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", self.current_pos == self.max_pos, self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False, self.__on_save_clicked)
        self.save_function = self.__save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button, self.save_button]

        label_total = Label(value='/ {}'.format(len(self.objects)))
        # Keep max >= min even with zero objects: a bounded int widget may
        # reject max < min at construction, and start_classification()
        # already reports the empty case.
        self.text_index = BoundedIntText(value=1, min=1, max=max(len(self.objects), 1))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)  # CSS hook used by the screenshot JS in __save_function

        metrics_labels = self.dataset_voc.read_label_metrics_name_from_file()
        metrics_labels['truncated'] = {'0': 'False', '1': 'True'}
        self.metrics_labels = metrics_labels

        self.checkboxes = {}
        self.radiobuttons = {}

        output_layout = []
        for m_n in self.metrics:
            if m_n == 'parts':  # special case: rebuilt per object in __perform_action
                continue

            if m_n in ['truncated', 'occ']:  # single choice -> radio buttons
                self.radiobuttons[m_n] = RadioButtons(options=list(metrics_labels[m_n].values()),
                                                      disabled=False,
                                                      indent=False)
            else:  # multiple choice -> one checkbox per label
                self.checkboxes[m_n] = [Checkbox(False, description='{}'.format(metrics_labels[m_n][i]),
                                                 indent=False) for i in metrics_labels[m_n].keys()]

        self.check_radio_boxes_layout = {}
        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        # output area for the per-object PARTS checkboxes (filled in __perform_action)
        self.dynamic_output_for_parts = Output()
        html_title = HTML(value="<b>" + "Parts" + "</b>")
        output_layout.append(VBox([html_title, self.dynamic_output_for_parts]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox(output_layout),
                                 self.out])

        # load the html2canvas JS library used for html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        """Return the ``(mapping, dataset)`` pair used to store results.

        NOTE(review): currently a stub that always returns two empty dicts.
        The earlier COCO-style implementation (categories/images lists loaded
        from *file_path*) is disabled, so *file_path* and *cc* are unused.
        """
        return {}, {}

    def __checkbox_changed(self, b):
        """Record a checkbox or radio-button change for the current object.

        Ignores notifications that are not 'value' changes or whose new value
        is None.  Adds or removes the category index in
        ``self.dataset["images"][...]["categories"]``.

        NOTE(review): with the stubbed ``__create_results_dict`` the
        ``self.mapping["images"]`` lookup raises KeyError on any user edit;
        confirm the intended results schema.
        """
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value

        current_index = self.mapping["images"][self.objects[self.current_pos]]
        class_index = self.mapping["categories_name"][class_name]
        categories = self.dataset["images"][current_index]["categories"]
        if class_index not in categories and value:
            categories.append(class_index)
        if class_index in categories and not value:
            categories.remove(class_index)

    def __create_button(self, description, disabled, function):
        """Return a Button with the given label and enabled state, wired to *function*."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_record, obj_num):
        """Default renderer: show the image with object *obj_num*'s bbox and class label.

        The image is read from ``<output_path>/JPEGImages/<filename>``.  The box
        colour is picked from a rainbow colormap by the class's position in
        ``dataset_voc.objnames_all``.
        """
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        img = Image.open(path_img)
        if self.show_name:
            print(os.path.basename(path_img) + '. Class: {} [class_id={}]'.format(
                self.objects[self.current_pos]['class_name'],
                self.objects[self.current_pos]['class_id']))
        plt.figure(figsize=self.fig_size)

        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)

        # draw the bbox of object obj_num
        ax = plt.gca()
        class_colors = cm.rainbow(np.linspace(0, 1, len(self.dataset_voc.objnames_all)))

        [bbox_x1, bbox_y1, bbox_x2, bbox_y2] = image_record['objects'][obj_num]['bbox']
        poly = [[bbox_x1, bbox_y1], [bbox_x1, bbox_y2], [bbox_x2, bbox_y2],
                [bbox_x2, bbox_y1]]
        np_poly = np.array(poly).reshape((4, 2))

        object_class_name = image_record['objects'][obj_num]['class']
        c = class_colors[self.dataset_voc.objnames_all.index(object_class_name)]

        # transparent fill, opaque edge
        ax.add_patch(
            Polygon(np_poly, linestyle='-', facecolor=(c[0], c[1], c[2], 0.0),
                    edgecolor=(c[0], c[1], c[2], 1.0), linewidth=2))

        # class name at the bbox's (x1, y1) corner
        ax.text(x=bbox_x1, y=bbox_y1, s=object_class_name, color='white', fontsize=9, horizontalalignment='left',
                verticalalignment='top',
                bbox=dict(facecolor=(c[0], c[1], c[2], 0.5)))
        plt.show()

    def save_state(self):
        """Write ``self.dataset`` to ``self.file_path`` as indented JSON."""
        with open(self.file_path, 'w') as output_file:
            json.dump(self.dataset, output_file, indent=4)

    def __save_function(self, image_path):
        """Inject JS that screenshots this annotator's output area and downloads it.

        The downloaded PNG is named after *image_path*'s base name without the
        extension.  The target element is selected by the CSS class ``self.name``.
        """
        img_name = os.path.basename(image_path).split('.')[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        j_code = j_code.replace('$img_name$', img_name)
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        """Sync buttons, metric widgets, parts list, image and index box with ``current_pos``.

        Observers are detached while widget values are set programmatically,
        so ``__checkbox_changed`` only sees user edits.
        """
        self.next_button.disabled = (self.current_pos == self.max_pos)
        self.previous_button.disabled = (self.current_pos == 0)

        current_class_id = self.objects[self.current_pos]['class_id']
        current_gt_id = self.objects[self.current_pos]['gt_id']
        gt = self.dataset_voc.annotations_gt['gt'][current_class_id]

        if 'occ' in self.radiobuttons:
            cb = self.radiobuttons['occ']
            cb.unobserve(self.__checkbox_changed)
            cb.value = None  # clear the current value
            details = gt['details'][current_gt_id]
            if details:  # empty details -> leave unselected
                # occ_level is used as a 1-based index into the options
                cb.value = cb.options[details['occ_level'] - 1]
            cb.observe(self.__checkbox_changed)

        if 'truncated' in self.radiobuttons:  # PASCAL VOC always has a truncation value
            cb = self.radiobuttons['truncated']
            cb.unobserve(self.__checkbox_changed)
            # options are ['False', 'True'] (see __init__); pick index 1 when truncated
            cb.value = cb.options[int(gt['istrunc'][current_gt_id] == True)]  # noqa: E712
            cb.observe(self.__checkbox_changed)

        if 'views' in self.checkboxes:
            for cb in self.checkboxes['views']:
                cb.unobserve(self.__checkbox_changed)
                cb.value = False  # clear the value
                details = gt['details'][current_gt_id]
                if details:  # empty details -> leave unchecked
                    cb.value = bool(details['side_visible'][cb.description])
                cb.observe(self.__checkbox_changed)

        # the parts checkboxes depend on the object's class, so rebuild them
        with self.dynamic_output_for_parts:
            self.dynamic_output_for_parts.clear_output()
            class_name = self.objects[self.current_pos]['class_name']
            if class_name in self.metrics_labels['parts']:
                self.cb_parts = [Checkbox(False, description='{}'.format(i), indent=False)
                                 for i in self.metrics_labels['parts'][class_name]]
            else:
                self.cb_parts = [HTML(value="No PARTS defined in Conf file")]
            display(VBox(children=list(self.cb_parts)))

        with self.out:
            self.out.clear_output()
            image_record, obj_num = self.__get_image_record()
            self.image_display_function(image_record, obj_num)

        # update the index box without re-triggering __selected_index
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        """Return ``(image_record, obj_num)`` for the object at ``current_pos``."""
        current_class_id = self.objects[self.current_pos]['class_id']
        current_gt_id = self.objects[self.current_pos]['gt_id']
        gt = self.dataset_voc.annotations_gt['gt'][current_class_id]

        obj_num = gt['onum'][current_gt_id]
        index_row = gt['rnum'][current_gt_id]
        return self.dataset_voc.annotations_gt['rec'][index_row], obj_num

    def __on_previous_clicked(self, b):
        """Save, then step back one object."""
        self.save_state()
        self.current_pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Save, then step forward one object."""
        self.save_state()
        self.current_pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Save the results file and download a screenshot of the current view."""
        self.save_state()
        image_record, _ = self.__get_image_record()
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        self.save_function(path_img)

    def __selected_index(self, t):
        """Jump to the 1-based index typed into the index box."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.__perform_action()

    def start_classification(self):
        """Display the widgets and show the first object (or report that there is none)."""
        if self.max_pos < self.current_pos:
            print("No available images")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        """Print a table of annotated-image counts per category.

        NOTE(review): reads ``self.dataset["images"]`` and
        ``self.mapping["categories_id"]``, which the stubbed
        ``__create_results_dict`` does not populate.
        """
        counter = defaultdict(int)
        for record in self.dataset["images"]:
            for c in record["categories"]:
                counter[c] += 1
        table = sorted(([self.mapping["categories_id"][c], n] for c, n in counter.items()),
                       key=lambda x: x[0])
        print(tabulate(table, headers=['Class name', 'Annotated images']))
Exemple #9
0
class PopulationsTab(object):
    """Jupyter tab plotting per-cell-type population counts over output frames.

    Counts are read from ``output%08d_cells_physicell.mat`` files in
    ``self.output_dir``.  Cell-type names come from ``config.xml`` in the
    current working directory.
    """

    def __init__(self):
        self.output_dir = '.'

        self.figsize_width_substrate = 15.0  # allow extra for colormap
        self.figsize_height_substrate = 8
        self.figsize_width_svg = 12.0
        self.figsize_height_svg = 12.0

        self.first_time = True
        self.modulo = 1

        self.use_defaults = True

        self.svg_delta_t = 1
        self.substrate_delta_t = 1
        self.svg_frame = 1
        self.substrate_frame = 1

        self.customized_output_freq = False
        self.therapy_activation_time = 1000000
        self.max_svg_frame_pre_therapy = 1000000
        self.max_substrate_frame_pre_therapy = 1000000

        self.svg_xmin = 0

        self.show_nucleus = False
        self.show_edge = True

        # initial value
        self.field_index = 4

        self.skip_cb = False

        # define dummy size of mesh (set in the tool's primary module)
        self.numx = 0
        self.numy = 0

        self.title_str = ''

        max_frames = 0
        self.pop_plot = interactive(self.plot_celltypes,
                                    frame=(0, max_frames),
                                    continuous_update=False)

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        # react only to value changes (an unfiltered observe() sees every trait change)
        self.max_frames.observe(self.update_max_frames, names='value')

        self.tab = VBox([self.max_frames, self.pop_plot])

    def update_max_frames(self, _b):
        """Copy the 'Max' box value to the frame slider's upper bound."""
        self.pop_plot.children[0].max = self.max_frames.value

    def update(self, rdir=''):
        """Optionally switch to *rdir*, then size the slider to the newest snapshot.

        Assumes snapshots are named ``snapshot%08d.svg``.
        """
        if rdir:
            self.output_dir = rdir

        all_files = sorted(
            glob.glob(os.path.join(self.output_dir, 'snapshot*.svg')))
        if all_files:
            last_file = all_files[-1]
            self.max_frames.value = int(last_file[-12:-4])  # the %08d frame number

    def plot_celltypes(self, frame=None):
        """Stack-plot cell counts per cell type for frames ``0 .. frame-1``.

        Args:
            frame: number of output frames to read (the slider value).  None or
                values <= 0 draw nothing.

        Side effects: resets ``self.title_str`` and fills ``self.pop_counts``
        (type id -> one count per frame read).  When there is something to
        plot, also sets ``self.fig``.  Prints a message and returns early if a
        frame file is missing or ``config.xml`` cannot be parsed.
        """
        if frame is None or frame <= 0:
            return
        self.title_str = ''
        self.pop_counts = {}

        for i in range(frame):
            fname = "output%08d_cells_physicell.mat" % i
            full_fname = os.path.join(self.output_dir, fname)

            if not os.path.isfile(full_fname):
                print("Once output files are generated, click the slider.")
                return

            info_dict = {}
            scipy.io.loadmat(full_fname, info_dict)

            # row 5 of 'cells' is used as the cell-type id (presumably
            # PhysiCell's cell_type row -- confirm against the output spec)
            M = info_dict['cells'][5, :].astype(int)

            unique, counts = np.unique(M, return_counts=True)
            pop_size = dict(zip(unique, counts))

            # A type first seen at frame i gets zero counts for frames 0..i-1.
            for key in pop_size:
                self.pop_counts.setdefault(key, [0] * i)
            # Every series gets exactly one entry per frame (0 when the type is
            # absent), so all series stay aligned with range(frame).
            for key, series in self.pop_counts.items():
                series.append(pop_size.get(key, 0))

        config_file = "config.xml"

        cell_lines = {}
        if os.path.isfile(config_file):
            try:
                tree = ET.parse(config_file)
            except (ET.ParseError, OSError):
                print("Cannot parse", config_file, "- check it's XML syntax.")
                return

            root = tree.getroot()
            uep = root.find('.//cell_definitions')  # find unique entry point (uep)
            if uep is not None:
                for child in uep.findall('cell_definition'):
                    cell_lines[int(child.attrib["ID"])] = child.attrib["name"]

        t_data = []
        t_names = []
        for t_id, name in cell_lines.items():
            if t_id in self.pop_counts:
                t_data.append(self.pop_counts[t_id])
                t_names.append(name)

        if not t_data:
            # no named cell type has counts: nothing to stack
            print("No known cell types to plot.")
            return

        # create the figure only once there is data, to avoid an empty plot
        self.fig = plt.figure(figsize=(self.figsize_width_substrate,
                                       self.figsize_height_substrate))
        ax = self.fig.add_subplot(111)
        ax.stackplot(range(frame), t_data, labels=t_names)
        ax.legend(labels=t_names,
                  loc='upper center',
                  bbox_to_anchor=(0.5, -0.05),
                  ncol=2)
Exemple #10
0
class SVGTab(object):
    """Jupyter tab rendering ``snapshot%08d.svg`` frames as a scatter plot of cells."""

    def __init__(self):
        tab_height = '520px'
        tab_layout = Layout(
            width='800px',  # border='2px solid black',
            height=tab_height,
            overflow_y='scroll')

        self.output_dir = '.'

        max_frames = 505  # first time + 30240 / 60
        self.svg_plot = interactive(self.plot_svg,
                                    frame=(0, max_frames),
                                    continuous_update=False)
        svg_plot_size = '500px'
        self.svg_plot.layout.width = svg_plot_size
        self.svg_plot.layout.height = svg_plot_size
        self.use_defaults = True

        self.show_nucleus = 0  # 0->False, 1->True in Checkbox!
        self.show_edge = 1  # 0->False, 1->True in Checkbox!
        self.scale_radius = 1.0
        self.axes_min = 0.0
        self.axes_max = 2000  # hmm, this can change (TODO?)

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(flex='1 1 auto', width='auto'),
        )
        # observers react to value changes only (an unfiltered observe() is
        # notified of every trait change and would trigger redundant redraws)
        self.max_frames.observe(self.update_max_frames, names='value')

        self.show_nucleus_checkbox = Checkbox(
            description='nucleus',
            value=False,
            disabled=False,
            layout=Layout(flex='1 1 auto', width='auto'),
        )
        self.show_nucleus_checkbox.observe(self.show_nucleus_cb, names='value')

        self.show_edge_checkbox = Checkbox(
            description='edge',
            value=True,
            disabled=False,
            layout=Layout(flex='1 1 auto', width='auto'),
        )
        self.show_edge_checkbox.observe(self.show_edge_cb, names='value')

        items_auto = [
            Label('(select slider: drag or left/right arrows)'),
            self.max_frames,
            self.show_nucleus_checkbox,
            self.show_edge_checkbox,
        ]
        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='90%')
        row1 = Box(children=items_auto, layout=box_layout)
        self.tab = VBox([row1, self.svg_plot], layout=tab_layout)

    def show_nucleus_cb(self, b):
        """Mirror the 'nucleus' checkbox into ``self.show_nucleus`` (1/0) and redraw."""
        self.show_nucleus = 1 if self.show_nucleus_checkbox.value else 0
        self.svg_plot.update()

    def show_edge_cb(self, b):
        """Mirror the 'edge' checkbox into ``self.show_edge`` (1/0) and redraw."""
        self.show_edge = 1 if self.show_edge_checkbox.value else 0
        self.svg_plot.update()

    def update_max_frames(self, _b):
        """Copy the 'Max' box value to the frame slider's upper bound."""
        self.svg_plot.children[0].max = self.max_frames.value

    def plot_svg(self, frame):
        """Draw the cells of ``<output_dir>/snapshot%08d.svg`` for *frame*.

        Sets the module-level ``current_frame``.  When ``use_defaults`` is on,
        sets ``self.axes_max`` from a top-level child's ``width`` attribute.
        Prints a message and returns early when the file, or the id'd tissue
        group and its ``cells`` group, are missing.
        """
        global current_frame
        current_frame = frame
        fname = "snapshot%08d.svg" % frame

        full_fname = os.path.join(self.output_dir, fname)
        if not os.path.isfile(full_fname):
            print("No: ", full_fname)
            return

        xlist = deque()
        ylist = deque()
        rlist = deque()
        rgb_list = deque()

        tree = ET.parse(full_fname)
        root = tree.getroot()

        # Defaults so a snapshot without a "Current time" label or without an
        # id'd group is handled instead of hitting unbound names.
        title_str = ''
        tissue_parent = None
        for child in root:
            if self.use_defaults and ('width' in child.attrib.keys()):
                self.axes_max = float(child.attrib['width'])
            if child.text and "Current time" in child.text:
                svals = child.text.split()
                title_str = svals[2] + "d, " + svals[4] + "h, " + svals[7] + "m"
            # the first top-level child carrying an id is taken as the tissue group
            if 'id' in child.attrib.keys():
                tissue_parent = child
                break

        if tissue_parent is None:
            print("No tissue group found in", full_fname)
            return

        cells_parent = None
        for child in tissue_parent:
            if child.attrib.get('id') == 'cells':
                cells_parent = child
                break

        if cells_parent is None:
            print("No cells group found in", full_fname)
            return

        num_cells = 0
        for child in cells_parent:
            for circle in child:  # two circles in each child: outer + nucleus
                xval = float(circle.attrib['cx'])

                s = circle.attrib['fill']
                if s[0:3] == "rgb":  # an rgb string, e.g. "rgb(175,175,80)"
                    rgb = [int(x) / 255. for x in s[4:-1].split(",")]
                else:  # otherwise, must be a color name
                    rgb = list(mplc.to_rgb(mplc.cnames[s]))

                # test for bogus x,y locations (rwh TODO: use max of domain?)
                too_large_val = 10000.
                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    break
                yval = float(circle.attrib['cy'])
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    break

                rval = float(circle.attrib['r'])
                xlist.append(xval)
                ylist.append(yval)
                rlist.append(rval)
                rgb_list.append(rgb)

                # without nucleus display only the first (outer) circle is used
                if self.show_nucleus == 0:
                    break

            num_cells += 1

        xvals = np.array(xlist)
        yvals = np.array(ylist)
        rvals = np.array(rlist)
        rgbs = np.array(rgb_list)

        title_str += " (" + str(num_cells) + " agents)"

        fig = plt.figure(figsize=(6, 6))

        # convert radii (data units) to pixels via the axes' data transform
        ax2 = fig.gca()
        N = len(xvals)
        rr_pix = (ax2.transData.transform(np.vstack([rvals, rvals]).T) -
                  ax2.transData.transform(
                      np.vstack([np.zeros(N), np.zeros(N)]).T))
        rpix, _ = rr_pix.T

        # = (2*rpix / fig.dpi * 72)**2, then divided by a fixed factor (4e6;
        # its origin is not documented here)
        markers_size = (144. * rpix / fig.dpi)**2
        markers_size = markers_size / 4000000.

        if self.show_edge:
            # Matplotlib documents linewidths as float/array-like: pass a
            # number, not the string '0.5'
            plt.scatter(xvals,
                        yvals,
                        s=markers_size,
                        c=rgbs,
                        edgecolor='black',
                        linewidth=0.5)
        else:
            plt.scatter(xvals, yvals, s=markers_size, c=rgbs)
        plt.xlim(self.axes_min, self.axes_max)
        plt.ylim(self.axes_min, self.axes_max)
        plt.title(title_str)
Exemple #11
0
class Annotator:

    def __init__(self,
                 dataset,
                 metrics_and_values,  # properties{ name: ([possible values], just_one_value = True)}
                 output_path=None,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):

        self.dataset_orig = dataset
        self.metrics_and_values = metrics_and_values
        self.show_axis = show_axis
        self.name = (self.dataset_orig.dataset_root_param.split('/')[-1]).split('.')[0]  # get_original_file_name
        self.show_name = show_name
        if output_path is None:
            splitted_array = self.dataset_orig.dataset_root_param.split('/')
            n = len(splitted_array)
            self.output_directory = os.path.join(*(splitted_array[0:n - 1]))
        else:
            self.output_directory = output_path

        if classes_to_annotate is None:  # if classes_to_annotate is None, all the classes would be annotated
            self.classes_to_annotate = self.dataset_orig.get_categories_names()  # otherwise, the only the classes in the list

        self.file_path_for_json = os.path.join("/", self.output_directory, self.name + "_ANNOTATED.json")
        print("New dataset with meta_annotations {}".format(self.file_path_for_json))
        self.mapping = None
        self.objects = self.dataset_orig.get_annotations_from_class_list(self.classes_to_annotate)
        self.max_pos = len(self.objects) - 1
        self.current_pos = 0
        self.mapping, self.dataset_annotated = self.__create_results_dict(self.file_path_for_json)

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # create buttons
        self.previous_button = self.__create_button("Previous", (self.current_pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", (self.current_pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button("Download", False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button]
        buttons.append(self.save_button)

        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class("my_canvas_class")

        self.checkboxes = {}
        self.radiobuttons = {}
        self.bounded_text = {}

        output_layout = []
        for k_name, v in self.metrics_and_values.items():
            if MetaPropertiesTypes.META_PROP_UNIQUE == v[1]:  # radiobutton
                self.radiobuttons[k_name] = RadioButtons(name=k_name, options=v[0],
                                                         disabled=False,
                                                         indent=False)
            elif MetaPropertiesTypes.META_PROP_COMPOUND == v[1]:  # checkbox
                self.checkboxes[k_name] = [Checkbox(False, description='{}'.format(prop_name),
                                                    indent=False, name=k_name) for prop_name in v[0]]
            elif MetaPropertiesTypes.META_PROP_CONTINUE == v[1]:
                self.bounded_text[k_name] = BoundedFloatText(value=v[0][0], min=v[0][0], max=v[0][1])

        self.check_radio_boxes_layout = {}

        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            rb_v.add_class(rb_k)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
                cb.add_class(cb_k)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=[cb for cb in cb_i])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for bf_k, bf in self.bounded_text.items():
            bf.layout.width = '80px'
            bf.layout.height = '35px'
            bf.observe(self.__checkbox_changed)
            bf.add_class(bf_k)
            html_title = HTML(value="<b>" + bf_k + "</b>")
            self.check_radio_boxes_layout[bf_k] = VBox([bf])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[bf_k]]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox([self.out,
                                 VBox(output_layout)])])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __add_annotation_to_mapping(self, ann):
        cat_name = self.dataset_orig.get_category_name_from_id(ann['category_id'])
        if ann['id'] not in self.mapping['annotated_ids']:  # o nly adds category counter if it was not annotate
            self.mapping['categories_counter'][cat_name] += 1

        self.mapping['annotated_ids'].add(ann['id'])

    def __create_results_dict(self, file_path):
        mapping = {}
        mapping["annotated_ids"] = set()  # key: object_id from dataset, values=[(annotations done)]
        mapping["categories_counter"] = dict.fromkeys([c for c in self.dataset_orig.get_categories_names()], 0)
        self.mapping = mapping

        if not os.path.exists("/" + file_path): #it does exist __ANNOTATED in the output directory
            with open(self.dataset_orig.dataset_root_param, 'r') as input_json_file:
                dataset_annotated = json.load(input_json_file)
                input_json_file.close()
            #take the same metaproperties already in file if it's not empty (like a new file)
            meta_prop = dataset_annotated['meta_properties'] if 'meta_properties' in dataset_annotated.keys() else []

            # adds the new annotations categories to dataset if it doesn't exist
            for k_name, v in self.metrics_and_values.items():
                new_mp_to_append = {
                    "name": k_name,
                    "type": v[1].value,
                    "values": [p for p in v[0]].sort()
                }

                names_prop_in_file = {m_p['name']: m_p for m_i, m_p in enumerate(dataset_annotated[
                                                                                     'meta_properties'])} if 'meta_properties' in dataset_annotated.keys() else None

                if 'meta_properties' not in dataset_annotated.keys():  # it is a new file
                    meta_prop.append(new_mp_to_append)
                    dataset_annotated['meta_properties'] = []

                elif names_prop_in_file is not None and k_name not in names_prop_in_file.keys():
                    # if there is a property with the same in meta_properties, it must be the same structure as the one proposed
                    meta_prop.append(new_mp_to_append)
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                elif names_prop_in_file is not None and k_name in names_prop_in_file.keys() and \
                        names_prop_in_file[k_name] == new_mp_to_append:
                    #we don't append because it's already there
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                else:
                    raise NameError("An annotation with the same name {} "
                                    "already exist in dataset {}, and it has different structure. Check properties.".format(
                        k_name, self.dataset_orig.dataset_root_param))

                #if k_name is in name_props_in_file and it's the same structure. No update is done.
            dataset_annotated['meta_properties'] = dataset_annotated['meta_properties'] + meta_prop
        else:
            with open(file_path, 'r') as classification_file:
                dataset_annotated = json.load(classification_file)
                classification_file.close()

            self.__update_annotation_counter_and_current_pos(dataset_annotated)
        return mapping, dataset_annotated

    def __update_annotation_counter_and_current_pos(self, dataset_annotated):
        prop_names = set(self.metrics_and_values.keys())
        last_ann_id = self.objects[0]['id']
        for ann in dataset_annotated['annotations']:
            if prop_names.issubset(
                    set(ann.keys())):  # we consider that it was annotated when all the props are subset of keys
                self.__add_annotation_to_mapping(ann)
                last_ann_id = ann['id']  # to get the last index in self.object so we can update the current pos in the last one anotated
            else:
                break
        self.current_pos = next(i for i, a in enumerate(self.objects) if a['id'] == last_ann_id)

    def __checkbox_changed(self, b):
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value
        annotation_name = b['owner']._dom_classes[0]

        ann_id = self.objects[self.current_pos]['id']
        # image_id = self.objects[self.current_pos]['image_id']
        for ann in self.dataset_annotated['annotations']:
            if ann['id'] == ann_id:
                break
        if self.metrics_and_values[annotation_name][1] in [MetaPropertiesTypes.META_PROP_COMPOUND]:
            if annotation_name not in ann.keys():
                ann[annotation_name] = {p: 0 for p in self.metrics_and_values[annotation_name][0]}
            ann[annotation_name][class_name] = int(value)
        else:  # UNIQUE VALUE
            ann[annotation_name] = value

    def __create_button(self, description, disabled, function):
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_record, ann_key):
        #   read img from path and show it
        path_img = os.path.join(self.dataset_orig.images_abs_path, image_record['file_name'])
        img = Image.open(path_img)
        if self.show_name:
            print(os.path.basename(path_img) + '. Class: {} [class_id={}]'.format(
                self.dataset_orig.get_category_name_from_id(self.objects[self.current_pos]['category_id']),
                self.objects[self.current_pos]['category_id']))
        plt.figure(figsize=self.fig_size)

        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)

        # draw the bbox from the object onum
        ax = plt.gca()
        class_colors = cm.rainbow(np.linspace(0, 1, len(self.dataset_orig.get_categories_names())))

        annotation = self.dataset_orig.coco_lib.anns[ann_key]
        object_class_name = self.dataset_orig.get_category_name_from_id(annotation['category_id'])
        c = class_colors[self.dataset_orig.get_categories_names().index(object_class_name)]
        if not self.dataset_orig.is_segmentation:
            [bbox_x1, bbox_y1, diff_x, diff_y] = annotation['bbox']
            bbox_x2 = bbox_x1 + diff_x
            bbox_y2 = bbox_y1 + diff_y
            poly = [[bbox_x1, bbox_y1], [bbox_x1, bbox_y2], [bbox_x2, bbox_y2],
                    [bbox_x2, bbox_y1]]
            np_poly = np.array(poly).reshape((4, 2))

            # draws the bbox
            ax.add_patch(
                Polygon(np_poly, linestyle='-', facecolor=(c[0], c[1], c[2], 0.0),
                        edgecolor=(c[0], c[1], c[2], 1.0), linewidth=2))
        else:
            seg_points = annotation['segmentation']
            for pol in seg_points:
                poly = [[float(pol[i]), float(pol[i + 1])] for i in range(0, len(pol), 2)]
                np_poly = np.array(poly)  # .reshape((len(pol), 2))
                ax.add_patch(
                    Polygon(np_poly, linestyle='-', facecolor=(c[0], c[1], c[2], 0.25),
                            edgecolor=(c[0], c[1], c[2], 1.0), linewidth=2))
            # set the first XY point for printing the text
            bbox_x1 = seg_points[0][0];
            bbox_y1 = seg_points[0][1]

        #  write the class name in bbox
        ax.text(x=bbox_x1, y=bbox_y1, s=object_class_name, color='white', fontsize=9, horizontalalignment='left',
                verticalalignment='top',
                bbox=dict(facecolor=(c[0], c[1], c[2], 0.5)))
        plt.show()

    def save_state(self):
        w = SafeWriter(os.path.join(self.file_path_for_json), "w")
        w.write(json.dumps(self.dataset_annotated))
        w.close()
        self.__add_annotation_to_mapping(next(ann_dat for ann_dat in self.dataset_annotated['annotations']
                                              if ann_dat['id'] == self.objects[self.current_pos]['id']))

    def __save_function(self, image_path):
        img_name = os.path.basename(image_path).split('.')[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', "my_canvas_class")
        j_code = j_code.replace('$img_name$', img_name)
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        self.next_button.disabled = (self.current_pos == self.max_pos)
        self.previous_button.disabled = (self.current_pos == 0)

        current_class_id = self.objects[self.current_pos]['category_id']
        current_ann = next(
            a for a in self.dataset_annotated['annotations'] if a['id'] == self.objects[self.current_pos]['id'])
        # current_gt_id = self.objects[self.current_pos]['gt_id']

        for m_k, m_v in self.metrics_and_values.items():
            if m_v[1] == MetaPropertiesTypes.META_PROP_UNIQUE:  # radiobutton
                self.radiobuttons[m_k].unobserve(self.__checkbox_changed)
                self.radiobuttons[m_k].value = current_ann[m_k] if m_k in current_ann.keys() else None
                self.radiobuttons[m_k].observe(self.__checkbox_changed)
            elif m_v[1] == MetaPropertiesTypes.META_PROP_COMPOUND:  # checkbox
                for cb_i, cb_v in enumerate(self.checkboxes[m_k]):
                    cb_v.unobserve(self.__checkbox_changed)
                    cb_v.value = bool(current_ann[m_k][cb_v.description]) if m_k in current_ann.keys() else False
                    cb_v.observe(self.__checkbox_changed)
            elif m_v[1] == MetaPropertiesTypes.META_PROP_CONTINUE:  # textbound
                self.bounded_text[m_k].unobserve(self.__checkbox_changed)
                self.bounded_text[m_k].value = float(current_ann[m_k]) if m_k in current_ann.keys() else \
                    self.bounded_text[m_k].min
                self.bounded_text[m_k].observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            image_record, ann_key = self.__get_image_record()
            self.image_display_function(image_record, ann_key)

        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        # current_class_id = self.objects[self.current_pos]['category_id']
        current_image_id = self.objects[self.current_pos]['image_id']

        ann_key = self.objects[self.current_pos]['id']
        img_record = self.dataset_orig.coco_lib.imgs[current_image_id]
        return img_record, ann_key

    def __on_previous_clicked(self, b):
        self.save_state()
        self.current_pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        self.save_state()
        self.current_pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        self.save_state()
        image_record, _ = self.__get_image_record()
        path_img = os.path.join(self.output_directory, 'JPEGImages', image_record['file_name'])
        self.save_function(path_img)

    def __selected_index(self, t):
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.__perform_action()

    def start_annotation(self):
        if self.max_pos < self.current_pos:
            print("No available images")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        table = []
        total = 0
        for c_k, c_number in self.mapping["categories_counter"].items():
            table.append([c_k, c_number])
            total += c_number
        table = sorted(table, key=lambda x: x[0])
        table.append(['Total', '{}/{}'.format(total, len(self.objects))])
        print(tabulate(table, headers=['Class name', 'Annotated objects']))
Exemple #12
0
class ImBox(VBox):
    """Widget for inspecting images that contain bounding boxes."""
    def __init__(self,
                 df: pd.DataFrame,
                 box_col: str = 'box',
                 img_col: str = 'image',
                 text_cols: Union[str, List[str]] = None,
                 text_fmts: Union[Callable, List[Callable]] = None,
                 style_col: str = None):
        """
        :param pd.DataFrame df: `DataFrame` with images and boxes
        :param str box_col: column in the dataframe that contains boxes
        :param str img_col: column in the dataframe that contains image paths
        :param Union[str, List[str]] text_cols: (optional) the column(s) in the
        dataframe to use for creating the text that is shown on top of a box.
        When multiple columns are given, the text is a comma-separated list of
        the contents of the given columns.
        :param Union[Callable, List[Callable]] text_fmts: (optional) a
        callable, or list of callables, that takes the corresponding value from
        the `text_cols` column(s) as an input and returns the string to print
        for that value. Columns without a formatter (e.g. when fewer
        formatters than columns are given) are printed with `str`.
        :param str style_col: the column containing a dict of style attributes.
        Available attributes are:
            - `stroke_width`: the stroke width of a box (default 2)
            - `stroke_color`: the stroke color of a box (default 'red')
            - `fill_color`: the fill color of a box (default  '#00000000')
            - `hover_fill`: the fill color of a box when it is hovered on
              (default '#00000088')
            - `hover_stroke`: the stroke color of a box when it is hovered on
              (default 'blue')
            - `active_fill`: the fill color of a box when it is clicked on
              (default '#ffffff22')
            - `active_stroke`: the stroke color of a box when it is clicked on
              (default 'green')
            - `font_family`: the font family to use for box text (default
            'arial'). NOTE: exported text will always be Arial.
            - `font_size`: the font size in points (default 10)
        """
        if text_cols is None:
            text_cols = []
        if isinstance(text_cols, str):
            text_cols = [text_cols]
        if text_fmts is None:
            text_fmts = []
        if isinstance(text_fmts, Callable):
            text_fmts = [text_fmts]
        # zip() in row2text stops at the shorter list, so pad with None (plain
        # str) to keep every text column in the box text.
        text_fmts = list(text_fmts) + [None] * (len(text_cols) - len(text_fmts))
        self.text_cols = text_cols
        self.text_fmts = text_fmts

        df2 = df.copy()

        def row2text(row):
            """Join the (formatted) text-column values of one row with ', '."""
            return ', '.join(
                fmt(txt) if fmt is not None else str(txt)
                for txt, fmt in zip(row[text_cols], self.text_fmts))

        if style_col is None:
            style_col = '_dfim_style'
            df2[style_col] = [DEFAULT_STYLE] * len(df2)
        else:
            # Fill in missing style attributes from DEFAULT_STYLE.
            df2[style_col] = df2[style_col].apply(lambda s: {
                k: s[k] if k in s else DEFAULT_STYLE[k]
                for k in DEFAULT_STYLE
            })

        df2['box_text'] = df2.apply(row2text, axis=1)
        # Rows without a box get None instead of a box dict.
        df2['box_dict'] = df2.apply(
            lambda row: dict(index=row.name,
                             box=row[box_col],
                             text=row['box_text'],
                             style=row[style_col])
            if (box_col in row.index and row[box_col] is not None) else None,
            axis=1)

        # One row per image; every other column becomes a list over its boxes.
        self.df_img = df2.groupby(img_col).agg(list).reset_index()
        self.df = df
        self.img_col = img_col
        self.box_col = box_col
        last_idx = len(self.df_img) - 1

        # SELECTION widget
        self.idx_wgt = BoundedIntText(value=0,
                                      min=0,
                                      max=last_idx,
                                      step=1,
                                      description='Choose index',
                                      disabled=False)
        self.drop_wgt = Dropdown(options=self.df_img[img_col],
                                 description='or image',
                                 value=None,
                                 disabled=False)
        self.drop_wgt.observe(self.drop_changed, names='value')
        self.idx_wgt.observe(self.idx_changed, names='value')
        self.imsel_wgt = VBox([self.idx_wgt, self.drop_wgt])
        self.imsel_wgt.layout = Layout(margin='auto')

        # IMAGE PANE
        self.img_title = HTML(placeholder='(Image path)')
        self.img_title.layout = Layout(margin='auto')
        self.imbox_wgt = ImBoxWidget()
        self.imbox_wgt.layout = Layout(margin='1em auto')
        self.imbox_wgt.observe(self.box_changed, names='active_box')
        self.imbox_wgt.observe(self.img_changed, names='img')

        # DETAILS PANE
        self.crop_wgt = CropBoxWidget()
        self.crop_wgt.layout = Layout(margin='0 1em')
        self.detail_wgt = DetailsWidget()
        self.detail_wgt.layout = Layout(margin='auto')
        self.detail_pane = HBox([self.crop_wgt, self.detail_wgt])
        self.detail_pane.layout = Layout(margin='1em auto')

        # PLAY widget
        self.play_btns = Play(interval=100,
                              value=0,
                              min=0,
                              max=last_idx,
                              step=1,
                              description="Play",
                              disabled=False)
        self.play_slider = widgets.IntSlider(value=0,
                                             min=0,
                                             max=last_idx,
                                             step=1)
        widgets.jslink((self.play_btns, 'value'), (self.idx_wgt, 'value'))
        widgets.jslink((self.play_btns, 'value'), (self.play_slider, 'value'))

        self.play_wgt = widgets.HBox([self.play_btns, self.play_slider])
        self.play_wgt.layout = Layout(margin='auto')

        # IMAGE EXPORT widget
        self.imexp_dest = Text(description='Output file',
                               value='output/output.png')
        self.imexp_btn = Button(description='Export')
        self.imexp_btn.on_click(self.export_img)
        self.imexp_wgt = HBox([self.imexp_dest, self.imexp_btn])

        # VIDEO EXPORT widget
        self.videxp_dest = Text(description='Output file',
                                value='output/output.mp4')
        self.videxp_start = BoundedIntText(value=0,
                                           min=0,
                                           max=last_idx,
                                           step=1,
                                           description='From index',
                                           disabled=False)
        self.videxp_start.observe(self.vididx_changed, names='value')
        self.videxp_end = BoundedIntText(value=0,
                                         min=0,
                                         max=last_idx,
                                         step=1,
                                         description='Until index',
                                         disabled=False)
        self.videxp_end.observe(self.vididx_changed, names='value')
        self.videxp_fps = FloatText(value=30, description='FPS')
        self.videxp_btn = Button(description='Export')
        self.videxp_btn.on_click(self.export_vid)

        self.videxp_wgt = VBox([
            HBox([self.videxp_start, self.videxp_end]),
            HBox([self.videxp_dest, self.videxp_fps]), self.videxp_btn
        ])
        self.exp_wgt = Tab(children=[self.imexp_wgt, self.videxp_wgt])
        self.exp_wgt.set_title(0, 'Export image')
        self.exp_wgt.set_title(1, 'Export video')
        self.exp_wgt.layout = Layout(margin='0 1em')

        super().__init__([
            self.imsel_wgt,
            VBox([
                self.img_title, self.imbox_wgt, self.play_wgt, self.detail_pane
            ]), self.exp_wgt
        ])
        self.idx_changed({'new': 0})

    def box_changed(self, change):
        """Show the row details and crop of the clicked box (or clear them)."""
        if change['new'] is None:
            self.detail_wgt.data = {}
            self.crop_wgt.box = None
        else:
            new_idx = change['new']['index']
            self.detail_wgt.data = dict(self.df.loc[new_idx])
            self.crop_wgt.box = change['new']['box']

    def img_changed(self, change):
        """Update the title, crop image and image-export state for a new image."""
        new_img = change['new']
        self.detail_wgt.data = {}
        self.crop_wgt.img = new_img
        self.img_title.value = f'Image path: <a href="{new_img}">{new_img}</a>'
        self.imexp_dest.value = f'output/{Path(new_img).stem}.png'
        self.imexp_btn.button_style = ''
        self.imexp_btn.description = 'Export'
        self.imexp_btn.disabled = False

    def drop_changed(self, change):
        """Select the image chosen in the dropdown."""
        if change['new'] is None:
            # Nothing selected: there is no matching row to look up.
            return
        idx = self.df_img[self.df_img[self.img_col] == change['new']].index[0]
        self.idx = idx
        self.imbox_wgt.img = self.df_img.loc[idx, self.img_col]
        self.imbox_wgt.boxes = self.df_img.loc[idx, 'box_dict']
        self.idx_wgt.value = idx

    def idx_changed(self, change):
        """Select the image at the chosen index."""
        self.idx = change['new']
        self.imbox_wgt.img = self.df_img.loc[self.idx, self.img_col]
        self.imbox_wgt.boxes = self.df_img.loc[self.idx, 'box_dict']
        self.drop_wgt.value = self.imbox_wgt.img

    def vididx_changed(self, change):
        """Reset the video-export state and default file name after a range change."""
        start = self.videxp_start.value
        end = self.videxp_end.value
        self.videxp_dest.value = f'output/{start}_{end}.mp4'
        self.videxp_btn.button_style = ''
        self.videxp_btn.description = 'Export'
        self.videxp_btn.disabled = False

    @staticmethod
    def _text_height(draw, text, font):
        """Return the rendered height of ``text``.

        ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox``
        (Pillow >= 8) is preferred. ``textsize`` is kept as the fallback for
        older Pillow versions.
        """
        if hasattr(draw, 'textbbox'):
            _, _, _, bottom = draw.textbbox((0, 0), text, font=font)
            return bottom
        _, height = draw.textsize(text, font=font)
        return height

    def get_pilim_from_idx(self, idx):
        """Return the processed PIL image that belongs to an image index.

        Boxes and their texts are drawn onto the image. Rows without a box
        (``box_dict`` is ``None``) are skipped.
        """
        im = Image.open(self.df_img.loc[idx, self.img_col])
        draw = ImageDraw.Draw(im, mode='RGBA')
        fontfile = str(Path(__file__).parent / 'etc/Arial.ttf')
        fonts = {}  # font size -> loaded font, so each size is read once

        for bd in self.df_img.loc[idx, 'box_dict']:
            if bd is None:
                continue
            box = bd['box']
            style = bd['style']
            draw.rectangle([(box.x_min, box.y_min), (box.x_max, box.y_max)],
                           fill=style['fill_color'],
                           outline=style['stroke_color'],
                           width=style['stroke_width'])

            # size*4 to make it look more similar to example in widget
            fontsize = style['font_size'] * 4
            if fontsize not in fonts:
                fonts[fontsize] = ImageFont.truetype(fontfile, size=fontsize)
            font = fonts[fontsize]
            h = self._text_height(draw, bd['text'], font)
            # The text sits just above the box's top-left corner.
            draw.text((box.x_min, box.y_min - h),
                      text=bd['text'],
                      fill=style['stroke_color'],
                      font=font)
        return im

    def export_img(self, button):
        """Save the current image with its boxes to the path in ``imexp_dest``."""
        self.imexp_btn.disabled = True
        self.imexp_btn.description = 'Exporting...'
        try:
            # Rendering is inside the try so a failure does not leave the
            # button stuck on 'Exporting...'. PIL raises ValueError for an
            # unknown file extension.
            im = self.get_pilim_from_idx(self.idx)
            im.save(self.imexp_dest.value)
        except (IOError, KeyError, ValueError):
            self.imexp_btn.button_style = 'danger'
            self.imexp_btn.description = 'Export Failed'
            logging.exception('Export Failed')
        else:
            self.imexp_btn.button_style = 'success'
            self.imexp_btn.description = 'Export Successful'

    def export_vid(self, button):
        """Write images ``[videxp_start, videxp_end)`` as video frames to ``videxp_dest``."""
        self.videxp_btn.disabled = True
        self.videxp_btn.description = 'Exporting...'
        fps = str(self.videxp_fps.value)
        try:
            writer = FFmpegWriter(self.videxp_dest.value,
                                  inputdict={'-framerate': fps})
            try:
                for idx in tqdm(range(self.videxp_start.value, self.videxp_end.value)):
                    writer.writeFrame(np.array(self.get_pilim_from_idx(idx)))
            finally:
                # Close the writer even when a frame fails.
                writer.close()
        except (OSError, KeyError, ValueError):
            self.videxp_btn.button_style = 'danger'
            self.videxp_btn.description = 'Export failed'
            logging.exception('Export Failed')
        else:
            self.videxp_btn.button_style = 'success'
            self.videxp_btn.description = 'Export successful'
Exemple #13
0
class Iterator:
    """Page through a list of images one at a time in a Jupyter notebook.

    The widget shows a 1-based index box, Previous/Next (and optionally
    Random) buttons, a Save button and an ``Output`` area in which the
    current image is rendered by ``image_display_function``.
    """

    def __init__(self,
                 images,
                 name="iterator",
                 show_name=True,
                 show_axis=False,
                 show_random=True,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None):
        """Build the widgets (call ``start_iteration`` to display them).

        Args:
            images: sequence of items to browse; with the default renderer
                each item is an image file path.
            name: CSS class added to the output area; the Save button uses
                it to locate the area to screenshot.
            show_name: print the file name above each image (default
                renderer only).
            show_axis: keep matplotlib axes visible (default renderer only).
            show_random: add a "Random" button.
            fig_size: matplotlib figure size for the default renderer.
            buttons_vertical: stack the controls left of the output area
                instead of above it.
            image_display_function: ``f(item, index)`` used instead of the
                built-in matplotlib renderer.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("No images provided")

        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        self.show_random = show_random
        self.images = images
        self.max_pos = len(self.images) - 1
        self.pos = 0
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        self.previous_button = self.__create_button("Previous",
                                                    (self.pos == 0),
                                                    self.__on_previous_clicked)
        self.next_button = self.__create_button("Next",
                                                (self.pos == self.max_pos),
                                                self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False,
                                                self.__on_save_clicked)
        self.save_function = self.__save_function

        buttons = [self.previous_button, self.next_button]

        if self.show_random:
            self.random_button = self.__create_button("Random", False,
                                                      self.__on_random_clicked)
            buttons.append(self.random_button)

        buttons.append(self.save_button)

        label_total = Label(value='/ {}'.format(len(self.images)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.images))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)

        if self.buttons_vertical:
            self.all_widgets = HBox(children=[
                VBox(children=[HBox([self.text_index, label_total])] +
                     buttons), self.out
            ])
        else:
            self.all_widgets = VBox(children=[
                HBox([self.text_index, label_total]),
                HBox(children=buttons), self.out
            ])
        # Register html2canvas with require.js; the Save button uses it to
        # screenshot the output area.
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label/state wired to ``function``."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_path, index):
        """Default renderer: draw the image file at ``image_path`` with matplotlib."""
        # The context manager closes the file once imshow has read the pixels.
        with Image.open(image_path) as img:
            if self.show_name:
                print(os.path.basename(image_path))
            plt.figure(figsize=self.fig_size)
            if not self.show_axis:
                plt.axis('off')
            plt.imshow(img)
        plt.show()

    def __save_function(self, image_path, index):
        """Download a PNG screenshot of the output area via html2canvas.

        The PNG is named after ``image_path`` with only its final extension
        removed (``a.b.png`` -> ``a.b.png`` download named ``a.b``).
        """
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        j_code = j_code.replace('$img_name$', img_name)
        # Run the script inside a throw-away Output so nothing stays visible.
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __on_next_clicked(self, b):
        """"Next" handler: advance one image."""
        self.pos += 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_save_clicked(self, b):
        """"Save" handler: pass the current item to ``save_function``."""
        self.save_function(self.images[self.pos], self.pos)

    def __perform_action(self, index, max_pos):
        """Show item ``index`` and sync the buttons and the index box."""
        self.next_button.disabled = (index == max_pos)
        self.previous_button.disabled = (index == 0)

        with self.out:
            self.out.clear_output()
            self.image_display_function(self.images[index], index)

        # Detach the observer while updating the box so the programmatic
        # change does not re-enter __selected_index.
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = index + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        """"Previous" handler: go back one image."""
        self.pos -= 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_random_clicked(self, b):
        """"Random" handler: jump to a uniformly random image."""
        self.pos = random.randint(0, self.max_pos)
        self.__perform_action(self.pos, self.max_pos)

    def __selected_index(self, t):
        """Index-box observer: jump to the typed (1-based) position."""
        # Every trait change arrives here; only react to a non-empty 'value'.
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.pos = t['new'] - 1
        self.__perform_action(self.pos, self.max_pos)

    def start_iteration(self):
        """Display the widgets and render the current image."""
        if self.max_pos < self.pos:
            print("No available images")
            return

        display(self.all_widgets)
        self.__perform_action(self.pos, self.max_pos)
class AnnotatorInterface(metaclass=abc.ABCMeta):
    """Abstract base for notebook annotators that attach meta-properties to a dataset.

    Builds the shared widget layout (index box, navigation/reset/download
    buttons, validation banner, an ``Output`` canvas and one input widget per
    meta-property) and loads/merges the ``<name>_ANNOTATED.json`` output
    file. Dataset-specific behaviour is supplied by subclasses through the
    abstract methods.
    """

    def __init__(self,
                 dataset,
                 properties_and_values,  # properties { name: (MetaPropType, [possible values], optional label)}
                 output_path=None,
                 show_name=False,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 ds_name=None,
                 custom_display_function=None,
                 classes_to_annotate=None,
                 validate_function=None,
                 show_reset=True
                 ):
        """Validate the requested properties and build the annotation UI.

        Args:
            dataset: source dataset; must expose ``dataset_root_param`` (path
                of its JSON file) and ``get_categories_names()``.
            properties_and_values: ``{name: (MetaPropertiesType, [possible
                values], optional label)}``; only UNIQUE and TEXT are accepted.
            output_path: directory for the annotated JSON; defaults to the
                directory of ``dataset.dataset_root_param``.
            show_name, show_axis, fig_size, buttons_vertical: display options
                stored for the subclass renderer.
            ds_name: name used for the output file; defaults to the source
                file name up to its first '.'.
            custom_display_function: forwarded to ``set_display_function``.
            classes_to_annotate: categories to annotate; ``None`` means all.
            validate_function: optional check run by ``execute_validation``.
            show_reset: add a Reset button.

        Raises:
            NotImplementedError: if a property type other than UNIQUE or TEXT
                is requested.
        """
        for v in properties_and_values.values():
            if v[0] not in [MetaPropertiesType.UNIQUE, MetaPropertiesType.TEXT]:
                raise NotImplementedError(f"Cannot use {v[0]}!")

        self.dataset_orig = dataset
        self.properties_and_values = properties_and_values
        self.show_axis = show_axis
        self.show_reset = show_reset
        self.show_name = show_name

        # Annotate every category unless an explicit subset was requested.
        if classes_to_annotate is None:
            self.classes_to_annotate = self.dataset_orig.get_categories_names()
        else:
            self.classes_to_annotate = classes_to_annotate

        if ds_name is None:
            self.name = (self.dataset_orig.dataset_root_param.split('/')[-1]).split('.')[0]  # get_original_file_name
        else:
            self.name = ds_name

        self.set_output(output_path)

        print("{} {}".format(labels_str.info_new_ds, self.file_path_for_json))

        self.current_pos = 0
        self.mapping, self.dataset_annotated = self.create_results_dict(self.file_path_for_json)
        # Subclass hook; presumably sets self.objects / self.max_pos, which
        # the widget builders below read.
        self.set_objects()

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        self.current_image = {}

        label_total = self.create_label_total()

        buttons = self.create_buttons()

        self.updated = False

        self.validation_show = HTML(value="")
        self.out = Output()
        # save_function locates the canvas to screenshot by this CSS class.
        self.out.add_class("my_canvas_class")

        self.checkboxes = {}
        self.radiobuttons = {}
        self.bounded_text = {}
        self.box_text = {}

        labels = self.create_check_radio_boxes()

        self.validate = validate_function is not None
        if self.validate:
            self.validate_function = validate_function

        self.set_display_function(custom_display_function)

        output_layout = self.set_check_radio_boxes_layout(labels)

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 self.validation_show,
                                 HBox([self.out,
                                       VBox(output_layout)])])

        self.load_js()

    @abc.abstractmethod
    def set_objects(self):
        """Subclass hook called during ``__init__`` before the widgets are built."""

    @abc.abstractmethod
    def set_display_function(self, custom_display_function):
        """Install the renderer, given the ``custom_display_function`` passed to ``__init__``."""

    def set_output(self, output_path):
        """Set ``output_directory`` and ``file_path_for_json`` (``<name>_ANNOTATED.json``).

        Args:
            output_path: target directory, or ``None`` to use the directory
                containing ``dataset_orig.dataset_root_param``.
        """
        if output_path is None:
            name = self.dataset_orig.dataset_root_param
            # dirname plus a trailing separator (empty when there is no
            # directory part); unlike str.replace it never touches other
            # occurrences of the file name inside the path.
            self.output_directory = os.path.join(os.path.dirname(name), "")
        else:
            self.output_directory = output_path

        self.file_path_for_json = os.path.join(self.output_directory, self.name + "_ANNOTATED.json")

    def create_label_total(self):
        """Create the 1-based index box (``self.text_index``) and return the "/ N" label."""
        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.selected_index)
        return label_total

    def create_buttons(self):
        """Create the navigation, optional reset, and download buttons; return them in display order."""
        self.previous_button = self.create_button(labels_str.str_btn_prev, (self.current_pos == 0),
                                                  self.on_previous_clicked)
        self.next_button = self.create_button(labels_str.str_btn_next, (self.current_pos == self.max_pos),
                                              self.on_next_clicked)
        self.save_button = self.create_button(labels_str.str_btn_download, False, self.on_save_clicked)

        if self.show_reset:
            self.reset_button = self.create_button(labels_str.str_btn_reset, False, self.on_reset_clicked)
            buttons = [self.previous_button, self.next_button, self.reset_button, self.save_button]
        else:
            buttons = [self.previous_button, self.next_button, self.save_button]
        return buttons

    def create_check_radio_boxes(self):
        """Create one input widget per property and return ``{property name: display label}``.

        The label is ``v[2]`` when given, ``v[1]`` for a 2-tuple TEXT
        property, and otherwise the property name.
        """
        labels = dict()
        for k_name, v in self.properties_and_values.items():

            if len(v) == 3:
                label = v[2]
            elif MetaPropertiesType.TEXT.value == v[0].value and len(v) == 2:
                label = v[1]
            else:
                label = k_name

            labels[k_name] = label

            if MetaPropertiesType.UNIQUE.value == v[0].value:  # radiobutton
                self.radiobuttons[k_name] = RadioButtons(name=k_name, options=v[1],
                                                         disabled=False,
                                                         indent=False)
            elif MetaPropertiesType.COMPOUND.value == v[0].value:  # checkbox
                self.checkboxes[k_name] = [Checkbox(False, indent=False, name=k_name,
                                                    description=prop_name) for prop_name in v[1]]
            elif MetaPropertiesType.CONTINUE.value == v[0].value:
                self.bounded_text[k_name] = BoundedFloatText(value=v[1][0], min=v[1][0], max=v[1][1])

            elif MetaPropertiesType.TEXT.value == v[0].value:
                self.box_text[k_name] = Textarea(disabled=False)

        return labels

    def set_check_radio_boxes_layout(self, labels):
        """Size the property widgets, attach ``checkbox_changed``, and return titled ``VBox`` columns."""
        output_layout = []
        self.check_radio_boxes_layout = {}
        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.checkbox_changed)
            rb_v.add_class(rb_k)
            html_title = HTML(value="<b>" + labels[rb_k] + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.checkbox_changed)
                cb.add_class(cb_k)
            html_title = HTML(value="<b>" + labels[cb_k] + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for bf_k, bf in self.bounded_text.items():
            bf.layout.width = '80px'
            bf.layout.height = '35px'
            bf.observe(self.checkbox_changed)
            bf.add_class(bf_k)
            html_title = HTML(value="<b>" + labels[bf_k] + "</b>")
            self.check_radio_boxes_layout[bf_k] = VBox([bf])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[bf_k]]))

        for tb_k, tb_i in self.box_text.items():
            tb_i.layout.width = '500px'
            tb_i.observe(self.checkbox_changed)
            tb_i.add_class(tb_k)
            html_title = HTML(value="<b>" + labels[tb_k] + "</b>")
            self.check_radio_boxes_layout[tb_k] = VBox([tb_i])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[tb_k]]))

        return output_layout

    def load_js(self):
        """Register html2canvas with require.js (used by ``save_function`` screenshots)."""
        j_code = """
                        require.config({
                            paths: {
                                html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                            }
                        });
                    """
        display(Javascript(j_code))

    def change_check_radio_boxes_value(self, current_ann):
        """Load ``current_ann`` values into the widgets without firing ``checkbox_changed``.

        Properties missing from ``current_ann`` reset to an empty/default value.
        """
        for m_k, m_v in self.properties_and_values.items():
            if m_v[0].value == MetaPropertiesType.UNIQUE.value:  # radiobutton
                self.radiobuttons[m_k].unobserve(self.checkbox_changed)
                self.radiobuttons[m_k].value = current_ann[m_k] if m_k in current_ann.keys() else None
                self.radiobuttons[m_k].observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.COMPOUND.value:  # checkbox
                for cb_v in self.checkboxes[m_k]:
                    cb_v.unobserve(self.checkbox_changed)
                    if m_k in current_ann.keys():
                        if cb_v.description in current_ann[m_k].keys():
                            cb_v.value = current_ann[m_k][cb_v.description]
                        else:
                            cb_v.value = False
                    else:
                        cb_v.value = False
                    cb_v.observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.CONTINUE.value:  # textbound
                self.bounded_text[m_k].unobserve(self.checkbox_changed)
                self.bounded_text[m_k].value = float(current_ann[m_k]) if m_k in current_ann.keys() else \
                    self.bounded_text[m_k].min
                self.bounded_text[m_k].observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.TEXT.value:  # text
                self.box_text[m_k].unobserve(self.checkbox_changed)
                self.box_text[m_k].value = current_ann[m_k] if m_k in current_ann.keys() else ""
                self.box_text[m_k].observe(self.checkbox_changed)

    def create_results_dict(self, file_path):
        """Load (or initialise) the annotated dataset and a fresh progress mapping.

        If ``file_path`` does not exist, the source dataset JSON is loaded
        and each requested property that is not yet listed in its
        ``meta_properties`` is appended once. A property that is already
        listed must have exactly the same name/type/values structure. If
        ``file_path`` exists, it is loaded unchanged. In both cases
        ``update_annotation_counter_and_current_pos`` runs exactly once on
        the resulting dataset.

        Args:
            file_path: path of the ``*_ANNOTATED.json`` output file.

        Returns:
            ``(mapping, dataset_annotated)``: ``mapping`` has
            ``annotated_ids`` (an empty set) and ``categories_counter``
            (category name -> 0); ``dataset_annotated`` is the parsed JSON.

        Raises:
            NameError: if a requested property already exists in the source
                dataset with a different structure.
        """
        mapping = {}
        mapping["annotated_ids"] = set()
        mapping["categories_counter"] = dict.fromkeys(self.dataset_orig.get_categories_names(), 0)
        # Published before the counter hook runs below.
        self.mapping = mapping

        if not os.path.exists(file_path):  # no __ANNOTATED file yet: start from the source dataset
            with open(self.dataset_orig.dataset_root_param, 'r') as input_json_file:
                dataset_annotated = json.load(input_json_file)

            # Work on a copy so appending new properties never mutates (or
            # later concatenates with) the list read from the file.
            existing_meta_props = list(dataset_annotated.get('meta_properties', []))
            names_prop_in_file = {m_p['name']: m_p for m_p in existing_meta_props}
            new_meta_props = []

            for k_name, v in self.properties_and_values.items():
                new_mp_to_append = {
                    "name": k_name,
                    "type": v[0].value,
                }

                if len(v) > 1:
                    new_mp_to_append["values"] = sorted(v[1])

                if k_name not in names_prop_in_file:
                    new_meta_props.append(new_mp_to_append)
                elif names_prop_in_file[k_name] != new_mp_to_append:
                    # Same name already present with a different structure.
                    raise NameError("An annotation with the same name {} "
                                    "already exist in dataset {}, and it has different structure. Check properties.".format(
                        k_name, self.dataset_orig.dataset_root_param))
                # Otherwise an identical property is already listed: nothing to add.

            dataset_annotated['meta_properties'] = existing_meta_props + new_meta_props
        else:
            with open(file_path, 'r') as classification_file:
                dataset_annotated = json.load(classification_file)

        self.update_annotation_counter_and_current_pos(dataset_annotated)
        return mapping, dataset_annotated

    @abc.abstractmethod
    def add_annotation_to_mapping(self, ann):
        """Subclass hook: record ``ann`` in the progress mapping."""

    @abc.abstractmethod
    def update_mapping_from_whole_dataset(self):
        """Subclass hook: rebuild the progress mapping from the dataset."""

    @abc.abstractmethod
    def update_annotation_counter_and_current_pos(self, dataset_annotated):
        """Subclass hook called once by ``create_results_dict`` with the loaded dataset."""

    def execute_validation(self, ann):
        """Run ``validate_function`` on ``ann`` (if configured) and show the verdict banner."""
        if self.validate:
            # NOTE(review): a truthy result shows the *not ok* message; confirm
            # validate_function reports problems rather than success.
            if self.validate_function(ann):
                self.validation_show.value = labels_str.srt_validation_not_ok
            else:
                self.validation_show.value = labels_str.srt_validation_ok

    @abc.abstractmethod
    def show_name_func(self, image_record, path_img):
        """Subclass hook for showing the current image's name."""

    @abc.abstractmethod
    def checkbox_changed(self, b):
        """Observer attached to every property widget."""

    def create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label/state wired to ``function``."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    @abc.abstractmethod
    def show_image(self, image_record, ann_key):
        """Subclass hook: render ``image_record``."""

    @abc.abstractmethod
    def save_state(self):
        """Subclass hook called before navigating away from the current item."""

    def save_function(self, image_path):
        """Download a PNG screenshot of the canvas, named after ``image_path`` minus its extension."""
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', "my_canvas_class")
        j_code = j_code.replace('$img_name$', img_name)
        # Run the script inside a throw-away Output so nothing stays visible.
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    @abc.abstractmethod
    def perform_action(self):
        """Subclass hook: show the item at ``self.current_pos`` (called after every navigation)."""

    @abc.abstractmethod
    def get_image_record(self):
        """Subclass hook: return the record of the current image."""

    @abc.abstractmethod
    def on_reset_clicked(self, b):
        """Handler for the Reset button."""

    def on_previous_clicked(self, b):
        """"Previous" handler: save the current state, then go back one item."""
        self.save_state()
        self.current_pos -= 1
        self.perform_action()

    def on_next_clicked(self, b):
        """"Next" handler: save the current state, then advance one item."""
        self.save_state()
        self.current_pos += 1
        self.perform_action()

    @abc.abstractmethod
    def on_save_clicked(self, b):
        """Handler for the download button."""

    def selected_index(self, t):
        """Index-box observer: jump to the typed (1-based) position."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.perform_action()

    def start_annotation(self):
        """Display the widgets and show the current item (if any remain)."""
        if self.max_pos < self.current_pos:
            print(labels_str.info_no_more_images)
            return
        display(self.all_widgets)
        self.perform_action()

    @abc.abstractmethod
    def print_statistics(self):
        """Subclass hook: print annotation statistics."""

    @abc.abstractmethod
    def print_results(self):
        """Subclass hook: print annotation results."""
Exemple #15
0
class SVGTab(object):

    def __init__(self):
        """Build the cell-plot (SVG snapshot) tab.

        Creates an ``interactive`` frame slider bound to ``self.plot_svg``,
        a "Max" frame box, nucleus/edge/tracks check boxes, and, when the
        module-level ``hublib_flag`` is truthy, a download row. The
        assembled widget is stored in ``self.tab``.
        """
        self.output_dir = '.'

        const_width = '180px'

        max_frames = 1
        self.svg_plot = interactive(self.plot_svg, frame=(0, max_frames), continuous_update=False)

        # "plot_size" controls the size of the tab height, not the plot (rf. figsize for that).
        # Other sizes that have been used: '500px' (small), '700px'/'750px' (medium).
        plot_size = '600px'
        self.svg_plot.layout.width = plot_size
        self.svg_plot.layout.height = plot_size
        self.use_defaults = True
        self.show_nucleus = 1  # 0->False, 1->True in Checkbox!
        self.show_edge = 1  # 0->False, 1->True in Checkbox!
        self.show_tracks = 1  # 0->False, 1->True in Checkbox!
        self.trackd = {}  # dictionary to hold cell IDs and their tracks: (x,y) pairs
        self.axes_min = 0.0
        self.axes_max = 2000   # hmm, this can change (TODO?)

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        self.show_nucleus_checkbox = Checkbox(
            description='nucleus', value=True, disabled=False,
            layout=Layout(width=const_width),
        )
        self.show_nucleus_checkbox.observe(self.show_nucleus_cb)

        self.show_edge_checkbox = Checkbox(
            description='edge', value=True, disabled=False,
            layout=Layout(width=const_width),
        )
        self.show_edge_checkbox.observe(self.show_edge_cb)

        self.show_tracks_checkbox = Checkbox(
            description='tracks', value=True, disabled=False,
            layout=Layout(width=const_width),
        )
        self.show_tracks_checkbox.observe(self.show_tracks_cb)

        items_auto = [
            Label('select slider: drag or left/right arrows'),
            self.max_frames,
            self.show_nucleus_checkbox,
            self.show_edge_checkbox,
            self.show_tracks_checkbox,
        ]
        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='70%')
        row1 = Box(children=items_auto, layout=box_layout)

        if hublib_flag:
            self.download_button = Download('svg.zip', style='warning', icon='cloud-download',
                                            tooltip='You need to allow pop-ups in your browser', cb=self.download_cb)
            download_row = HBox([self.download_button.w, Label("Download all cell plots (browser must allow pop-ups).")])
            self.tab = VBox([row1, self.svg_plot, download_row])
        else:
            self.tab = VBox([row1, self.svg_plot])


    # def update(self, rdir=''):
    def update(self, rdir=''):
        """Optionally switch to ``rdir``, then sync "Max" with the newest snapshot.

        Args:
            rdir: new output directory; an empty string keeps the current one.
        """
        if rdir:
            self.output_dir = rdir

        pattern = os.path.join(self.output_dir, 'snapshot*.svg')
        snapshots = sorted(glob.glob(pattern))
        if not snapshots:
            return

        newest = snapshots[-1]
        # Files are named "snapshot%08d.svg": the eight digits before ".svg"
        # are the frame number. Assigning the value fires update_max_frames.
        self.max_frames.value = int(newest[-12:-4])


    def download_cb(self):
        """Bundle every ``*.svg`` in ``output_dir`` into ``svg.zip`` in the working directory."""
        svg_glob = os.path.join(self.output_dir, '*.svg')
        with zipfile.ZipFile('svg.zip', 'w') as archive:
            for svg_path in glob.glob(svg_glob):
                # Store only the bare file name so the archive has no directory tree.
                archive.write(svg_path, arcname=os.path.basename(svg_path))

    def show_nucleus_cb(self, b):
        """Nucleus check-box observer: mirror its state into ``show_nucleus`` (0/1) and redraw.

        Args:
            b: the trait-change notification (unused).
        """
        self.show_nucleus = 1 if self.show_nucleus_checkbox.value else 0
        self.svg_plot.update()

    def show_edge_cb(self, b):
        """Edge check-box observer: mirror its state into ``show_edge`` (0/1) and redraw."""
        self.show_edge = 1 if self.show_edge_checkbox.value else 0
        self.svg_plot.update()

    def show_tracks_cb(self, b):
        """Tracks check-box observer: store the 0/1 flag, rebuild tracks when enabled, redraw."""
        tracks_on = bool(self.show_tracks_checkbox.value)
        self.show_tracks = int(tracks_on)
        # Track data is only (re)computed when the overlay is switched on.
        if tracks_on:
            self.create_all_tracks()
        self.svg_plot.update()


    # Note! this is called for EACH change to "Max" frames, which is with every new .svg file created!
    def update_max_frames(self, _b):
        """"Max" box observer: make it the upper bound of the frame slider.

        Fires on every change to "Max", including each automatic bump from ``update``.
        """
        frame_slider = self.svg_plot.children[0]
        frame_slider.max = self.max_frames.value

    #-----------------------------------------------------
    def circles(self, x, y, s, c='b', vmin=None, vmax=None, **kwargs):
        """
        See https://gist.github.com/syrte/592a062c562cd2a98a83 

        Make a scatter plot of circles on the current axes.
        Similar to plt.scatter, but the size of circles are in data scale.
        Parameters
        ----------
        x, y : scalar or array_like, shape (n, )
            Input data
        s : scalar or array_like, shape (n, ) 
            Radius of circles.
        c : color or sequence of color, optional, default : 'b'
            `c` can be a single color format string, or a sequence of color
            specifications of length `N`, or a sequence of `N` numbers to be
            mapped to colors using the `cmap` and `norm` specified via kwargs.
            Note that `c` should not be a single numeric RGB or RGBA sequence 
            because that is indistinguishable from an array of values
            to be colormapped. (If you insist, use `color` instead.)  
            `c` can be a 2-D array in which the rows are RGB or RGBA, however. 
        vmin, vmax : scalar, optional, default: None
            `vmin` and `vmax` are used in conjunction with `norm` to normalize
            luminance data.  If either are `None`, the min and max of the
            color array is used.
        kwargs : `~matplotlib.collections.Collection` properties
            Eg. alpha, edgecolor(ec), facecolor(fc), linewidth(lw), linestyle(ls), 
            norm, cmap, transform, etc.
        Returns
        -------
        paths : `~matplotlib.collections.PatchCollection`
            The collection added to the current axes.
        Examples
        --------
        a = np.arange(11)
        circles(a, a, s=a*0.2, c=a, alpha=0.5, ec='none')
        plt.colorbar()
        License
        --------
        This code is under [The BSD 3-Clause License]
        (http://opensource.org/licenses/BSD-3-Clause)
        """

        if np.isscalar(c):
            kwargs.setdefault('color', c)
            c = None

        # Expand the short matplotlib aliases; an explicit long name wins.
        for short, full in (('fc', 'facecolor'), ('ec', 'edgecolor'),
                            ('ls', 'linestyle'), ('lw', 'linewidth')):
            if short in kwargs:
                kwargs.setdefault(full, kwargs.pop(short))
        # You can set `facecolor` with an array for each patch,
        # while you can only set `facecolors` with a value for all.

        zipped = np.broadcast(x, y, s)
        patches = [Circle((x_, y_), s_)
                   for x_, y_, s_ in zipped]
        collection = PatchCollection(patches, **kwargs)
        if c is not None:
            c = np.broadcast_to(c, zipped.shape).ravel()
            collection.set_array(c)
            collection.set_clim(vmin, vmax)

        ax = plt.gca()
        ax.add_collection(collection)
        ax.autoscale_view()
        if c is not None:
            # Make this the "current image" so plt.colorbar() picks it up.
            plt.sci(collection)
        return collection

    #-------------------------
    def create_all_tracks(self, rdir=''):
        """Collect per-cell trajectories from the SVG snapshots in ``self.output_dir``.

        Reads ``snapshot%08d.svg`` for frames ``0 .. self.max_frames.value - 1``
        (stopping early at the first missing file). For every cell it takes the
        centre (``cx``, ``cy``) of the cell's first circle. Positions with
        ``|x|`` or ``|y|`` above 10000 are reported and skipped.

        The result replaces the contents of ``self.trackd``:
        ``{cell_id: ndarray of shape (n_points, 2)}``. It is always 2-D, even
        for a single point, so callers can slice ``track[:, 0]``.

        Parameters
        ----------
        rdir : str, optional
            If non-empty, becomes the new ``self.output_dir`` before reading.
        """
        if rdir:
            print('create_all_tracks():  rdir=', rdir)
            self.output_dir = rdir

        # check: if 0th .svg snapshot file doesn't exist, exit
        full_fname = os.path.join(self.output_dir, "snapshot%08d.svg" % 0)
        if not os.path.isfile(full_fname):
            print('create_all_tracks():  0th svg missing, return')
            return

        self.trackd.clear()
        too_large_val = 10000.  # reject bogus x,y locations (rwh TODO: use max of domain?)
        points = {}  # cell id -> list of [x, y]; converted to arrays once at the end

        for frame in range(self.max_frames.value):
            full_fname = os.path.join(self.output_dir, "snapshot%08d.svg" % frame)
            if not os.path.isfile(full_fname):
                print('create_all_tracks():  missing', full_fname, '- stopping')
                break
            root = ET.parse(full_fname).getroot()

            # The first top-level element carrying an 'id' is the tissue group;
            # look it up per frame so a malformed frame never reuses a stale one.
            tissue_parent = next((c for c in root if 'id' in c.attrib), None)
            if tissue_parent is None:
                continue
            cells_parent = next(
                (c for c in tissue_parent if c.attrib.get('id') == 'cells'), None)
            if cells_parent is None:
                continue

            for cell in cells_parent:
                # Each cell holds an outer circle (+ optional nucleus); only the
                # first circle's centre is tracked.
                circle = next(iter(cell), None)
                if circle is None:
                    continue
                xval = float(circle.attrib['cx'])
                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    continue
                yval = float(circle.attrib['cy'])
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    continue
                points.setdefault(cell.attrib['id'], []).append([xval, yval])

        # np.array of a list of [x, y] pairs is (n_points, 2), even when n_points == 1
        self.trackd.update({cell_id: np.array(xy) for cell_id, xy in points.items()})

    #-------------------------
    # def plot_svg(self, frame, rdel=''):
    def plot_svg(self, frame):
        """Plot the cells of SVG snapshot ``frame`` into a new 9x9-inch figure.

        Reads ``snapshot%08d.svg`` from ``self.output_dir``. Each cell is drawn
        with ``self.circles`` from its circle's ``cx``/``cy``/``r``/``fill``
        attributes. The nucleus circle is drawn too when ``self.show_nucleus`` is
        non-zero.

        Other behaviour:
        - When ``self.show_tracks`` is set, each track in ``self.trackd`` is
          overlaid up to (not including) index ``frame``.
        - When ``self.use_defaults`` is set, ``self.axes_max`` is taken from the
          SVG ``width`` attribute.
        - The title shows the simulation time (if present) and the number of cells.
        - If the snapshot does not exist yet, a hint is printed and nothing is drawn.

        Parameters
        ----------
        frame : int
            Snapshot index; also stored in the module-level ``current_frame``.
        """
        global current_frame
        current_frame = frame
        fname = "snapshot%08d.svg" % frame
        full_fname = os.path.join(self.output_dir, fname)
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")
            return

        root = ET.parse(full_fname).getroot()

        # Walk the top-level elements in document order. Pick up the domain width
        # and the "Current time" text, and stop at the first element carrying an
        # 'id' (the tissue group).
        title_str = ""  # stays empty if the snapshot has no time stamp
        tissue_parent = None
        for child in root:
            if self.use_defaults and 'width' in child.attrib:
                self.axes_max = float(child.attrib['width'])
            if child.text and "Current time" in child.text:
                svals = child.text.split()
                title_str = svals[2] + "d, " + svals[4] + "h, " + svals[7] + "m"
            if 'id' in child.attrib:
                tissue_parent = child
                break

        cells_parent = None
        if tissue_parent is not None:
            cells_parent = next(
                (c for c in tissue_parent if c.attrib.get('id') == 'cells'), None)

        xlist, ylist, rlist, rgb_list = [], [], [], []
        too_large_val = 10000.  # reject bogus x,y locations (rwh TODO: use max of domain?)
        num_cells = 0
        for child in (cells_parent if cells_parent is not None else []):
            for circle in child:  # outer circle first, then (optionally) the nucleus
                xval = float(circle.attrib['cx'])
                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    break
                yval = float(circle.attrib['cy'])
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    break

                fill = circle.attrib['fill']
                if fill.startswith("rgb"):  # an rgb string, e.g. "rgb(175,175,80)"
                    rgb = [int(v) / 255. for v in fill[4:-1].split(",")]
                else:  # a colour name (or hex string) that matplotlib understands
                    rgb = list(mplc.to_rgb(fill))

                xlist.append(xval)
                ylist.append(yval)
                rlist.append(float(circle.attrib['r']))
                rgb_list.append(rgb)

                if self.show_nucleus == 0:  # only the outer circle is wanted
                    break

            num_cells += 1

        xvals = np.array(xlist)
        yvals = np.array(ylist)
        rvals = np.array(rlist)
        rgbs = np.array(rgb_list)

        title_str = (title_str + " (" + str(num_cells) + " agents)").strip()

        # TODO: make figsize a function of plot_size? What about non-square plots?
        self.fig = plt.figure(figsize=(9, 9))

        if self.show_edge:
            try:
                self.circles(xvals, yvals, s=rvals, color=rgbs, edgecolor='black', linewidth=0.5)
            except ValueError:
                # rwh - temp fix: drawing with edges can raise ValueError; the cells
                # are then simply not drawn for this frame.
                pass
        else:
            self.circles(xvals, yvals, s=rvals, color=rgbs)

        if self.show_tracks:
            for track in self.trackd.values():
                # A single-point track may be stored as a 1-D [x, y] array.
                track = np.atleast_2d(track)
                plt.plot(track[0:frame, 0], track[0:frame, 1], linewidth=5)

        plt.xlim(self.axes_min, self.axes_max)
        plt.ylim(self.axes_min, self.axes_max)
        plt.title(title_str)
Exemple #16
0
class ConfigTab(object):
    """Basic-configuration tab: domain, run time, thread count and output settings.

    The widgets are laid out in ``self.tab``. ``fill_gui`` copies values from
    a PhysiCell settings XML tree into the widgets, and ``fill_xml`` writes the
    widget values back into such a tree.
    """

    def __init__(self):
        constWidth = '180px'
        stepsize = 10
        disable_domain = False
        # BoundedFloatText defaults to min=0 / max=100. Bound every domain edge
        # explicitly on both sides so that, e.g., x_min=500 or x_max=-10 is not
        # silently clamped.
        domain_limit = 5000

        def domain_edge(description):
            return BoundedFloatText(
                step=stepsize,
                description=description,
                min=-domain_limit,
                max=domain_limit,
                disabled=disable_domain,
                layout=Layout(width=constWidth),
            )

        label_domain = Label('Domain (micron):')
        self.xmin = domain_edge('Xmin')
        self.ymin = domain_edge('Ymin')
        self.zmin = domain_edge('Zmin')
        self.xmax = domain_edge('Xmax')
        self.ymax = domain_edge('Ymax')
        self.zmax = domain_edge('Zmax')

        self.tmax = BoundedFloatText(
            min=0.,
            max=100000000,
            step=stepsize,
            description='Max Time',
            layout=Layout(width=constWidth),
        )
        self.xdelta = BoundedFloatText(
            min=1.,
            description='dx',
            disabled=disable_domain,
            layout=Layout(width=constWidth),
        )
        self.ydelta = BoundedFloatText(
            min=1.,
            description='dy',
            disabled=True,  # driven by dx, see xdelta_cb
            layout=Layout(width=constWidth),
        )
        self.zdelta = BoundedFloatText(
            min=1.,
            description='dz',
            disabled=disable_domain,
            layout=Layout(width=constWidth),
        )

        def xdelta_cb(change):
            # dy mirrors dx; the z range is set to exactly one voxel of size dz
            # centred on 0.
            self.ydelta.value = self.xdelta.value
            self.zdelta.value = 0.5 * (self.xdelta.value + self.ydelta.value)
            self.zmin.value = -0.5 * self.zdelta.value
            self.zmax.value = 0.5 * self.zdelta.value

        self.xdelta.observe(xdelta_cb, names='value')

        x_row = HBox([self.xmin, self.xmax, self.xdelta])
        y_row = HBox([self.ymin, self.ymax, self.ydelta])

        self.omp_threads = BoundedIntText(
            min=1,
            max=4,
            description='# threads',
            layout=Layout(width=constWidth),
        )

        self.toggle_svg = Checkbox(
            description='Cells',  # SVG
            layout=Layout(width='150px'))
        # Explicit max: the Bounded* widgets otherwise default to max=100.
        self.svg_interval = BoundedIntText(
            min=1,
            max=99999999,
            description='every',
            layout=Layout(width='160px'),
        )
        self.mcds_interval = BoundedIntText(
            min=1,
            max=99999999,
            description='every',
            layout=Layout(width='160px'),
        )

        def svg_interval_cb(change):
            # keep the SVG interval no larger than the full-data interval
            if self.svg_interval.value > self.mcds_interval.value:
                self.svg_interval.value = self.mcds_interval.value

        # BEWARE: fill_gui relies on the order in which these two are filled.
        self.svg_interval.observe(svg_interval_cb, names='value')

        def mcds_interval_cb(change):
            # keep the full-data interval no smaller than the SVG interval
            if self.mcds_interval.value < self.svg_interval.value:
                self.mcds_interval.value = self.svg_interval.value

        self.mcds_interval.observe(mcds_interval_cb, names='value')

        def toggle_svg_cb(change):
            self.svg_interval.disabled = not self.toggle_svg.value

        self.toggle_svg.observe(toggle_svg_cb, names='value')

        self.toggle_mcds = Checkbox(
            description='Substrates',  # Full
            layout=Layout(width='180px'),
        )

        def toggle_mcds_cb(change):
            self.mcds_interval.disabled = not self.toggle_mcds.value

        self.toggle_mcds.observe(toggle_mcds_cb, names='value')

        svg_mat_output_row = HBox([
            Label('Plots:'), self.toggle_svg,
            HBox([self.svg_interval, Label('min')]), self.toggle_mcds,
            HBox([self.mcds_interval, Label('min')])
        ])

        # The z row is intentionally not shown.
        domain_box = VBox([label_domain, x_row, y_row], layout=Layout(border='1px solid'))
        self.tab = VBox([
            domain_box,
            HBox([self.tmax, Label('min')]),
            self.omp_threads,
            svg_mat_output_row,
        ])

    # Populate the GUI widgets with values from the XML
    def fill_gui(self, xml_root):
        """Copy domain, time, thread and output settings from ``xml_root`` into the widgets.

        Raises AttributeError if a required element is missing, because ``find``
        returns None.
        """
        def text(path):
            return xml_root.find(path).text

        def flag(path):
            # accepts 'true', 'True', ' true ', ...
            return text(path).strip().lower() == 'true'

        self.xmin.value = float(text(".//x_min"))
        self.xmax.value = float(text(".//x_max"))
        self.xdelta.value = float(text(".//dx"))

        self.ymin.value = float(text(".//y_min"))
        self.ymax.value = float(text(".//y_max"))
        self.ydelta.value = float(text(".//dy"))

        self.zmin.value = float(text(".//z_min"))
        self.zmax.value = float(text(".//z_max"))
        self.zdelta.value = float(text(".//dz"))

        self.tmax.value = float(text(".//max_time"))

        self.omp_threads.value = int(text(".//omp_num_threads"))

        self.toggle_mcds.value = flag(".//full_data//enable")
        self.mcds_interval.value = int(text(".//full_data//interval"))

        # NOTE: do this *after* filling the mcds_interval, directly above, due to the
        # callback/constraints on them.
        self.toggle_svg.value = flag(".//SVG//enable")
        self.svg_interval.value = int(text(".//SVG//interval"))

    # Read values from the GUI widgets and generate/write a new XML
    def fill_xml(self, xml_root):
        """Write the widget values into ``xml_root`` in place (elements must already exist).

        Booleans are written as Python ``str(bool)``, i.e. 'True' / 'False'.
        """
        # TODO: verify valid type (numeric) and range?
        for path, widget in (
                (".//x_min", self.xmin), (".//x_max", self.xmax), (".//dx", self.xdelta),
                (".//y_min", self.ymin), (".//y_max", self.ymax), (".//dy", self.ydelta),
                (".//z_min", self.zmin), (".//z_max", self.zmax), (".//dz", self.zdelta),
                (".//max_time", self.tmax),
                (".//omp_num_threads", self.omp_threads)):
            xml_root.find(path).text = str(widget.value)

        svg = xml_root.find(".//SVG")
        svg.find(".//enable").text = str(self.toggle_svg.value)
        svg.find(".//interval").text = str(self.svg_interval.value)
        full_data = xml_root.find(".//full_data")
        full_data.find(".//enable").text = str(self.toggle_mcds.value)
        full_data.find(".//interval").text = str(self.mcds_interval.value)

    def get_num_svg_frames(self):
        """Return ``int(max_time / svg_interval)``, or 0 when SVG output is disabled."""
        if self.toggle_svg.value:
            return int(self.tmax.value / self.svg_interval.value)
        return 0

    def get_num_substrate_frames(self):
        """Return ``int(max_time / full_data_interval)``, or 0 when full-data output is disabled."""
        if self.toggle_mcds.value:
            return int(self.tmax.value / self.mcds_interval.value)
        return 0
Exemple #17
0
class SubstrateTab(object):

    def __init__(self):
        """Build the substrate tab: field/colormap controls plus an interactive frame slider.

        ``self.mcds_plot`` wraps ``self.plot_substrate`` in an ipywidgets
        ``interactive`` with a ``frame`` slider (initially 0..1, non-continuous).
        The assembled layout is stored in ``self.tab``. When ``hublib_flag`` is
        true, the layout also includes a button that downloads ``mcds.zip`` via
        ``self.download_cb``.
        """
        self.output_dir = '.'

        # initial value
        self.field_index = 4

        # define dummy size of mesh (set in the tool's primary module)
        self.numx = 0
        self.numy = 0

        constWidth = '180px'
        constWidth2 = '150px'

        max_frames = 1
        self.mcds_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)
        # Sizes the widget area of the tab, not the plot itself (see figsize for that).
        svg_plot_size = '750px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        # Placeholders until the real substrate names are loaded.
        self.field_min_max = {'dummy': [0., 1.]}
        # index -> name: the reverse of the Dropdown's name -> index options below
        self.field_dict = {0: 'dummy'}

        self.mcds_field = Dropdown(
            options={'dummy': 0},
            value=0,
            layout=Layout(width=constWidth)
        )
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            layout=Layout(width=constWidth)
        )
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
        )

        self.save_min_max = Button(
            description='Save',
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px')
        )

        def save_min_max_cb(b):
            # Remember the current Min/Max entries as the selected field's colormap range.
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # Min/Max/Save are editable only while the colormap range is fixed.
            # The cmap_min/cmap_max observers above have no names= filter, so
            # flipping ``disabled`` here also notifies mcds_field_cb.
            locked = not self.cmap_fixed.value
            self.cmap_min.disabled = locked
            self.cmap_max.disabled = locked
            self.save_min_max.disabled = locked

        self.cmap_fixed.observe(cmap_fixed_cb)

        def flex_row_layout():
            # a fresh, identical Layout for each horizontal row box
            return Layout(border='0px solid black',
                          width='50%',
                          height='',
                          align_items='stretch',
                          flex_direction='row',
                          display='flex')

        help_label = Label('select slider: drag or left/right arrows')
        row1 = Box([help_label,
                    Box([self.max_frames, self.mcds_field, self.field_cmap], layout=flex_row_layout())])
        row2 = Box([self.cmap_fixed, self.cmap_min, self.cmap_max], layout=flex_row_layout())
        if hublib_flag:
            self.download_button = Download('mcds.zip', style='warning', icon='cloud-download',
                                            tooltip='Download data', cb=self.download_cb)
            download_row = HBox([self.download_button.w,
                                 Label("Download all substrate data (browser must allow pop-ups).")])
            self.tab = VBox([row1, row2, self.mcds_plot, download_row])
        else:
            self.tab = VBox([row1, row2, self.mcds_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        """Rebuild the substrate-field dropdown from ``<data_dir>/initial.xml``.

        Sets ``self.output_dir`` to ``data_dir``, then reads the ``variable``
        children of the first ``variables`` element. For each variable (indexed
        from 0 in document order) it records:
        - a default colormap range ``[0., 1.]`` in ``self.field_min_max``;
        - the index -> name mapping in ``self.field_dict``;
        - a name -> index option for ``self.mcds_field``.

        If the file cannot be opened or parsed, a message is printed and the
        field data is left unchanged.
        """
        self.output_dir = data_dir
        fname = os.path.join(self.output_dir, "initial.xml")
        try:
            xml_root = ET.parse(fname).getroot()
        except (OSError, ET.ParseError):
            print("Cannot open ",fname," to read info, e.g., names of substrate fields.")
            return

        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        # Test against None: an Element's truth value reflects whether it has
        # children (and is deprecated), not whether find() succeeded.
        if uep is not None:
            for field_idx, elm in enumerate(uep.findall('variable')):
                name = elm.attrib['name']
                self.field_min_max[name] = [0., 1.]
                self.field_dict[field_idx] = name
                dropdown_options[name] = field_idx

        self.mcds_field.value = 0
        self.mcds_field.options = dropdown_options

    def update_max_frames_expected(self, value):  # called when beginning an interactive Run
        """Prepare the frame controls for a run expected to produce *value* frames.

        Writes *value* into the "Max" box, then copies the box's (possibly
        clamped) value onto the upper bound of the frame slider, which is the
        first child of ``self.mcds_plot``.
        """
        frames_box = self.max_frames
        frames_box.value = value
        slider = self.mcds_plot.children[0]
        slider.max = frames_box.value

#    def update(self, rdir):
    def update(self, rdir=''):
        """Set the "Max" frame box to the newest output frame found on disk.

        Args:
            rdir: new output directory; when empty, ``self.output_dir`` is kept.

        Frame numbers are taken from files named ``output%08d.xml`` in the
        output directory, and the largest one becomes ``self.max_frames.value``.
        Files that match the ``output*.xml`` glob but not that naming scheme
        are ignored instead of breaking the integer parse. When no frame file
        exists, the box is left unchanged.
        """
        if rdir:
            self.output_dir = rdir

        frames = []
        for path in glob.glob(os.path.join(self.output_dir, 'output*.xml')):
            digits = os.path.basename(path)[len('output'):-len('.xml')]
            if digits.isdecimal():
                frames.append(int(digits))
        if frames:
            self.max_frames.value = max(frames)

    def download_cb(self):
        """Pack the output directory's ``*.xml`` and ``*.mat`` files into ``mcds.zip``.

        The archive is created (overwritten) in the current working directory.
        Each entry is stored under its base name only, so the archive does not
        reproduce the directory path. XML files are added before MAT files.
        """
        patterns = ('*.xml', '*.mat')
        with zipfile.ZipFile('mcds.zip', 'w') as archive:
            for pattern in patterns:
                for path in glob.glob(os.path.join(self.output_dir, pattern)):
                    # 2nd arg avoids full filename path in the archive
                    archive.write(path, os.path.basename(path))

    def update_max_frames(self, _b):
        """Observer for the "Max" box: make the frame slider's upper bound follow it.

        The change-event argument is not used.
        """
        slider = self.mcds_plot.children[0]
        slider.max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        """Dropdown callback: switch to the newly selected substrate field.

        Sets ``self.field_index`` to the dropdown index + 4. This is the row of
        the ``multiscale_microenvironment`` matrix that ``plot_substrate`` draws.
        The method then loads that field's saved colormap bounds into the
        Min/Max boxes and replots. It does nothing while the dropdown has no
        selection (``value is None``). The event argument *b* is unused.
        """
        field = self.mcds_field.value
        if field is None:
            return
        self.field_index = field + 4

        field_name = self.field_dict[field]
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        """Callback for the colormap / Min / Max widgets: replot the current field.

        Recomputes ``self.field_index`` (dropdown index + 4) and asks the
        interactive plot to redraw. When the dropdown has no selection
        (``value is None``, e.g. after its options were emptied), there is
        nothing to plot. The call is then a no-op instead of failing on
        ``None + 4``, matching ``mcds_field_changed_cb``. The event argument
        *b* is unused.
        """
        field = self.mcds_field.value
        if field is None:
            return
        self.field_index = field + 4
        self.mcds_plot.update()

    def plot_substrate(self, frame):
        """Draw a filled contour plot of the selected substrate for one output frame.

        Args:
            frame: output frame number. ``output%08d_microenvironment0.mat`` and
                ``output%08d.xml`` are read from ``self.output_dir``. If the
                ``.mat`` file is missing, a hint is printed and nothing is drawn.

        Row ``self.field_index`` of the ``multiscale_microenvironment`` matrix is
        contoured on the grid given by its rows 0 and 1. The grid size is parsed
        once from ``config.xml`` (x/y extent divided by dx/dy, rounded up) and
        cached in ``self.numx``/``self.numy``. With "Fix" checked, the contour
        levels span the Min/Max boxes; otherwise matplotlib picks 15 levels. The
        title is ``current_time`` from the XML shown as days/hours/minutes.
        """
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)

        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")  # No:  output00000000_microenvironment0.mat
            return

        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        # Round the float time itself; the former round(int(...)) truncated
        # first (e.g. 59.9999 -> 59), which made round() a no-op.
        mins = round(float(xml_root.find(".//current_time").text))  # TODO: check units = mins
        hrs = int(mins/60)
        days = int(hrs/24)
        title_str = '%dd, %dh, %dm' % (int(days),(hrs%24), mins - (hrs*60))

        info_dict = {}
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']

        self.fig = plt.figure(figsize=(24.0,20))  # this strange figsize results in a ~square contour plot

        if self.numx == 0:   # need to parse vals from the config.xml
            fname = os.path.join(self.output_dir, "config.xml")
            tree = ET.parse(fname)
            xml_root = tree.getroot()
            xmin = float(xml_root.find(".//x_min").text)
            xmax = float(xml_root.find(".//x_max").text)
            dx = float(xml_root.find(".//dx").text)
            ymin = float(xml_root.find(".//y_min").text)
            ymax = float(xml_root.find(".//y_max").text)
            dy = float(xml_root.find(".//dy").text)
            self.numx = math.ceil((xmax - xmin) / dx)
            self.numy = math.ceil((ymax - ymin) / dy)

        xgrid = M[0, :].reshape(self.numy, self.numx)
        ygrid = M[1, :].reshape(self.numy, self.numx)

        num_contours = 15
        levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
        fixed = self.cmap_fixed.value
        try:
            field = M[self.field_index, :].reshape(self.numy, self.numx)
            if fixed:
                my_plot = plt.contourf(xgrid, ygrid, field, levels=levels, extend='both', cmap=self.field_cmap.value)
            else:
                my_plot = plt.contourf(xgrid, ygrid, field, num_contours, cmap=self.field_cmap.value)
        except Exception:
            # Best effort: if the field cannot be contoured, leave the figure
            # without title/colorbar rather than breaking the widget. Catching
            # Exception (not a bare except) keeps KeyboardInterrupt working.
            return

        plt.title(title_str)
        plt.colorbar(my_plot)
Exemple #18
0
class SubstrateTab(object):
    """Notebook tab showing contour plots of substrate fields from simulation output.

    Wraps :meth:`plot_substrate` in an ``interactive`` frame slider and adds
    widgets to pick the substrate field, the colormap and, optionally, fixed
    colormap bounds. The assembled layout is available as ``self.tab``.
    """

    def __init__(self):

        self.output_dir = '.'

        # initial value; the dropdown index + 4 is the row of the
        # microenvironment matrix that plot_substrate draws
        self.field_index = 4

        constWidth = '180px'
        constWidth2 = '150px'

        max_frames = 1
        self.mcds_plot = interactive(self.plot_substrate,
                                     frame=(0, max_frames),
                                     continuous_update=False)
        svg_plot_size = '700px'
        self.mcds_plot.layout.width = svg_plot_size
        self.mcds_plot.layout.height = svg_plot_size

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        self.field_min_max = {'dummy': [0., 1.]}
        # hacky I know, but make a dict that's got (key,value) reversed from the dict in the Dropdown below
        self.field_dict = {0: 'dummy'}

        self.mcds_field = Dropdown(
            options={'dummy': 0},
            value=0,
            layout=Layout(width=constWidth))
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='viridis',
            layout=Layout(width=constWidth))
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed = Checkbox(
            description='Fix',
            disabled=False,
        )

        self.save_min_max = Button(
            description='Save',
            button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Save min/max for this substrate',
            disabled=True,
            layout=Layout(width='90px'))

        def save_min_max_cb(b):
            # Remember the current Min/Max boxes for the selected substrate.
            field_name = self.field_dict[self.mcds_field.value]
            self.field_min_max[field_name][0] = self.cmap_min.value
            self.field_min_max[field_name][1] = self.cmap_max.value

        self.save_min_max.on_click(save_min_max_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
            layout=Layout(width=constWidth2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_cb(b):
            # Min/Max and Save are editable only while "Fix" is checked.
            # NOTE(review): observe() above is registered without ``names``, so
            # presumably these ``disabled`` changes also fire mcds_field_cb and
            # trigger the replot for the new Fix state — confirm before
            # restricting the observers to 'value'.
            if self.cmap_fixed.value:
                self.cmap_min.disabled = False
                self.cmap_max.disabled = False
                self.save_min_max.disabled = False
            else:
                self.cmap_min.disabled = True
                self.cmap_max.disabled = True
                self.save_min_max.disabled = True

        self.cmap_fixed.observe(cmap_fixed_cb)

        help_label = Label('select slider: drag or left/right arrows')
        row1 = Box([
            help_label,
            Box([self.max_frames, self.mcds_field, self.field_cmap],
                layout=Layout(border='0px solid black',
                              width='50%',
                              height='',
                              align_items='stretch',
                              flex_direction='row',
                              display='flex'))
        ])
        row2 = Box([self.cmap_fixed, self.cmap_min, self.cmap_max],
                   layout=Layout(border='0px solid black',
                                 width='50%',
                                 height='',
                                 align_items='stretch',
                                 flex_direction='row',
                                 display='flex'))
        # NOTE(review): self.save_min_max is not part of self.tab, so the Save
        # button is never shown — confirm whether that is intended.
        self.tab = VBox([row1, row2, self.mcds_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        """Rebuild the substrate dropdown from ``<data_dir>/initial.xml``.

        Every ``variable`` under ``variables`` becomes a dropdown entry
        (name -> index, in document order). ``self.field_dict`` gets the
        reverse mapping and ``self.field_min_max`` is reset to ``[0., 1.]`` per
        field. If the file cannot be opened or parsed, a message is printed and
        nothing else changes (apart from ``self.output_dir``).
        """
        self.output_dir = data_dir
        fname = os.path.join(self.output_dir, "initial.xml")
        try:
            tree = ET.parse(fname)
        except (OSError, ET.ParseError):
            print("Cannot open ", fname, " to get names of substrate fields.")
            return

        xml_root = tree.getroot()
        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        # An Element's truth value means "has sub-elements" (and is
        # deprecated), so compare against None explicitly.
        if uep is not None:
            for field_idx, elm in enumerate(uep.findall('variable')):
                name = elm.attrib['name']
                self.field_min_max[name] = [0., 1.]
                self.field_dict[field_idx] = name
                dropdown_options[name] = field_idx

        self.mcds_field.value = 0
        self.mcds_field.options = dropdown_options

    def update_max_frames_expected(
            self, value):  # called when beginning an interactive Run
        """Set the "Max" box to *value* and widen the frame slider to match."""
        self.max_frames.value = value
        self.mcds_plot.children[0].max = self.max_frames.value

    def update(self, rdir):
        """Point the tab at *rdir* and jump "Max" to the newest output frame.

        An empty *rdir* selects ``./tmpdir`` (stored as an absolute path).
        Frame numbers come from files named ``output%08d.xml``; files that
        match the ``output*.xml`` glob but not that scheme are ignored. When at
        least one frame exists, "Max" is set to the largest one and the plot
        is refreshed.
        """
        self.output_dir = rdir if rdir != '' else os.path.abspath('tmpdir')

        frames = []
        for path in glob.glob(os.path.join(self.output_dir, 'output*.xml')):
            digits = os.path.basename(path)[len('output'):-len('.xml')]
            if digits.isdecimal():
                frames.append(int(digits))
        if frames:
            self.max_frames.value = max(frames)
            self.mcds_plot.update()

    def update_max_frames(self, _b):
        """Observer for the "Max" box: keep the slider's upper bound in sync."""
        self.mcds_plot.children[0].max = self.max_frames.value

    def mcds_field_changed_cb(self, b):
        """Dropdown callback: select a new field, load its saved bounds, replot.

        No-op while the dropdown has no selection (``value is None``).
        """
        field = self.mcds_field.value
        if field is None:
            return
        self.field_index = field + 4

        field_name = self.field_dict[field]
        self.cmap_min.value = self.field_min_max[field_name][0]
        self.cmap_max.value = self.field_min_max[field_name][1]
        self.mcds_plot.update()

    def mcds_field_cb(self, b):
        """Colormap/Min/Max callback: recompute ``field_index`` and replot.

        No-op while the dropdown has no selection, instead of failing on
        ``None + 4``.
        """
        field = self.mcds_field.value
        if field is None:
            return
        self.field_index = field + 4
        self.mcds_plot.update()

    def plot_substrate(self, frame):
        """Draw a filled contour plot of the selected substrate for one frame.

        Reads ``output%08d_microenvironment0.mat`` and ``output%08d.xml`` from
        ``self.output_dir``. If the ``.mat`` file is missing, a hint is printed
        and nothing is drawn. Row ``self.field_index`` of
        ``multiscale_microenvironment`` is reshaped to an N x N grid.
        """
        fname = "output%08d_microenvironment0.mat" % frame
        xml_fname = "output%08d.xml" % frame
        full_fname = os.path.join(self.output_dir, fname)
        full_xml_fname = os.path.join(self.output_dir, xml_fname)

        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider."
                  )  # No:  output00000000_microenvironment0.mat
            return

        tree = ET.parse(full_xml_fname)
        xml_root = tree.getroot()
        # Round the float time itself; round(int(...)) truncated first.
        mins = round(float(xml_root.find(
            ".//current_time").text))  # TODO: check units = mins
        hrs = int(mins / 60)
        days = int(hrs / 24)
        title_str = '%dd, %dh, %dm' % (int(days),
                                       (hrs % 24), mins - (hrs * 60))

        info_dict = {}
        scipy.io.loadmat(full_fname, info_dict)
        M = info_dict['multiscale_microenvironment']

        self.fig = plt.figure(figsize=(
            7.2, 6))  # this strange figsize results in a ~square contour plot

        # Assumes a square N x N grid; row 0 presumably holds x coordinates
        # and the same vector is used for both axes — TODO confirm.
        N = int(math.sqrt(len(M[0, :])))
        grid2D = M[0, :].reshape(N, N)
        xvec = grid2D[0, :]

        num_contours = 15
        levels = MaxNLocator(nbins=num_contours).tick_values(
            self.cmap_min.value, self.cmap_max.value)
        field = M[self.field_index, :].reshape(N, N)
        if self.cmap_fixed.value:
            my_plot = plt.contourf(xvec,
                                   xvec,
                                   field,
                                   levels=levels,
                                   extend='both',
                                   cmap=self.field_cmap.value)
        else:
            my_plot = plt.contourf(xvec,
                                   xvec,
                                   field,
                                   num_contours,
                                   cmap=self.field_cmap.value)

        plt.title(title_str)
        plt.colorbar(my_plot)
class PhysiBoSSTab(object):
    """Notebook tab charting how PhysiBoSS network states evolve during a run.

    Per-frame ``states_%08u.csv`` files (cell ID -> state string) are counted
    and drawn as a stacked area chart, either as percentages or as totals. When
    a single cell line is selected, ``output%08d_cells_physicell.mat`` files are
    used to map cell IDs to cell-definition IDs. Cell-line names and IDs are read
    from ``data/PhysiCell_settings.xml``.
    """

    def __init__(self):
        self.output_dir = '.'
        self.figsize_width = 15.0  # allow extra for colormap
        self.figsize_height = 8

        config_file = "data/PhysiCell_settings.xml"

        self.cell_lines = {}
        self.cell_lines_by_name = {}
        self.cell_lines_array = ["All"]
        self._read_cell_lines(config_file)

        # Chart state; set before the interactive widget is built so every
        # create_area_chart call finds it.
        self.count_dict = {}
        self.file_dict = {}
        self.cells_indexes = np.zeros((0))
        self.up_to_frame = 0

        max_frames = 0
        self.svg_plot = interactive(self.create_area_chart,
                                    frame=(0, max_frames),
                                    percentage=(0.0, 10.0),
                                    total=False,
                                    cell_line=self.cell_lines_array,
                                    continuous_update=False)
        # The layout size controls the tab area, not the plot (figsize_* does that).
        self.svg_plot.layout.width = '1000px'
        self.svg_plot.layout.height = '700px'
        self.use_defaults = True

        self.axes_min = 0.0
        self.axes_max = 2000  # hmm, this can change (TODO?)

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        items_auto = [
            Label('select slider: drag or left/right arrows'),
            self.max_frames,
        ]

        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='900px')
        row1 = Box(children=items_auto, layout=box_layout)

        self.tab = VBox([row1, self.svg_plot])

    def _read_cell_lines(self, config_file):
        """Fill the cell-line lookups from the ``cell_definitions`` of *config_file*.

        Problems are reported with ``print`` and leave only the "All" entry, so
        the rest of the tab can still be built.
        """
        if not os.path.isfile(config_file):
            print(config_file, "does not exist")
            return
        try:
            tree = ET.parse(config_file)
        except (OSError, ET.ParseError):
            print("Cannot parse", config_file, "- check it's XML syntax.")
            return

        uep = tree.getroot().find(
            './/cell_definitions')  # find unique entry point (uep)
        if uep is None:
            return
        for child in uep.findall('cell_definition'):
            cell_id = int(child.attrib["ID"])
            name = child.attrib["name"]
            self.cell_lines[cell_id] = name
            self.cell_lines_by_name[name] = cell_id
            self.cell_lines_array.append(name)

    def update(self, rdir=''):
        """Set "Max" to the newest ``snapshot%08d.svg`` frame in the output directory.

        A non-empty *rdir* replaces ``self.output_dir`` first.
        """
        if rdir:
            self.output_dir = rdir

        all_files = sorted(
            glob.glob(os.path.join(self.output_dir, 'snapshot*.svg')))
        if len(all_files) > 0:
            last_file = all_files[-1]
            self.max_frames.value = int(
                last_file[-12:-4])  # assumes naming scheme: "snapshot%08d.svg"

    def update_max_frames(self, _b):
        """Observer for the "Max" box: keep the frame slider's upper bound in sync."""
        self.svg_plot.children[0].max = self.max_frames.value

    def create_dict(self, number_of_files, folder):
        """Load ``states_%08u.csv`` for frames ``0 .. number_of_files-1`` into ``self.file_dict``.

        Each entry ``"state_step<i>"`` maps a cell ID (int, first CSV column)
        to its state string (second column). The header row (first column
        ``'ID'``) and blank lines are skipped. Frames that are already loaded
        are not re-read.

        Args:
            number_of_files: number of frames to load; nothing happens when <= 0.
            folder: directory holding the CSV files (callers pass ``self.output_dir``).
        """
        for i in range(number_of_files):
            key = "state_step{0}".format(i)
            if key in self.file_dict:
                continue
            states_dict = {}
            with open(os.path.join(folder, 'states_%08u.csv' % i),
                      newline='') as csvfile:
                for row in csv.reader(csvfile, delimiter=','):
                    if row and row[0] != 'ID':
                        states_dict[int(row[0])] = row[1]
            self.file_dict[key] = states_dict

    def state_counter(self, number_of_files, percentage, cell_indexes,
                      cell_line):
        """Count network states per frame into ``self.count_dict``.

        Args:
            number_of_files: number of frames to count (they must already be
                loaded by :meth:`create_dict`).
            percentage: states below this percentage of all counted cells (over
                all frames) are folded into ``"others"`` (see :meth:`filter_states`).
            cell_indexes: cell ID -> cell-definition ID array, as returned by
                :meth:`create_cell_indexes`; ``self.cells_indexes`` is used when None.
            cell_line: ``"All"`` or a name from ``self.cell_lines_by_name``;
                in the latter case only cells of that line are counted.

        ``self.count_dict`` ends up empty when *number_of_files* <= 0.
        """
        self.count_dict = {}
        if number_of_files <= 0:
            return
        if cell_indexes is None:
            cell_indexes = self.cells_indexes
        wanted_type = (None if cell_line == 'All' else
                       self.cell_lines_by_name[cell_line])

        temp_dict = {}
        max_cell = 0
        for i in range(number_of_files):
            step_states = self.file_dict["state_step{0}".format(i)]
            if wanted_type is None:
                state_list = list(step_states.values())
            else:
                state_list = [
                    state for cell_id, state in step_states.items()
                    if cell_indexes[cell_id] == wanted_type
                ]
            state_counts = Counter(state_list)
            max_cell += sum(state_counts.values())
            temp_dict["state_count{0}".format(i)] = state_counts
        self.count_dict = self.filter_states(max_cell, temp_dict, percentage)

    def create_cell_indexes(self, frame, cell_line):
        """Extend ``self.cells_indexes`` (cell ID -> cell-definition ID) up to *frame*.

        Reads rows 0 (cell IDs) and 5 (compared against cell-definition IDs by
        :meth:`create_area_chart`) of the ``cells`` matrix for frames
        ``self.up_to_frame .. frame-1``, growing the array with zeros as
        needed. Returns ``self.cells_indexes``. If a frame file is missing, a
        hint is printed and None is returned without advancing
        ``self.up_to_frame``. *cell_line* is unused.
        """
        for i in range(self.up_to_frame, frame):
            fname = "output%08d_cells_physicell.mat" % i
            full_fname = os.path.join(self.output_dir, fname)

            if not os.path.isfile(full_fname):
                print("Once output files are generated, click the slider."
                      )  # No:  output00000000_microenvironment0.mat
                return

            info_dict = {}
            scipy.io.loadmat(full_fname, info_dict)
            M = info_dict['cells'][[0, 5], :].astype(int)
            ids, cell_types = M[0, :], M[1, :]
            if ids.size == 0:
                continue  # no cells in this frame

            needed = int(ids.max()) + 1
            current = self.cells_indexes.shape[0]
            if needed > current:
                # Grow by copying instead of ndarray.resize, which refuses to
                # resize arrays that are referenced elsewhere.
                grown = np.zeros(needed, dtype=self.cells_indexes.dtype)
                grown[:current] = self.cells_indexes
                self.cells_indexes = grown
            self.cells_indexes[ids] = cell_types

        self.up_to_frame = frame
        return self.cells_indexes

    def create_area_chart(self,
                          frame=None,
                          total=False,
                          percentage=(0.0, 100.0),
                          cell_line="All"):
        "plot an area chart with the evolution of the network states during the simulation"

        cells_indexes = None
        if cell_line != "All":
            cells_indexes = self.create_cell_indexes(frame, cell_line)
            if np.sum(
                    cells_indexes == self.cell_lines_by_name[cell_line]) == 0:
                print("There are no %s cells." % cell_line)
                return

        self.create_dict(frame, self.output_dir)
        self.state_counter(frame, percentage, cells_indexes, cell_line)

        # Distinct states with a non-zero count in any frame, in first-seen order.
        all_state = list(
            dict.fromkeys(state for counts in self.count_dict.values()
                          for state, value in counts.items() if value > 0))

        # One row per frame, one column per state; missing states count as 0.
        a = np.array([[counts.get(state, 0) for state in all_state]
                      for counts in self.count_dict.values()])
        a = np.transpose(a)
        if not total:
            percent = a / a.sum(axis=0).astype(float) * 100
        else:
            percent = a
        x = np.arange(len(self.count_dict))
        self.fig = plt.figure(figsize=(self.figsize_width,
                                       self.figsize_height))
        ax = self.fig.add_subplot(111)
        ax.stackplot(x, percent, labels=all_state)
        ax.legend(labels=all_state,
                  loc='upper center',
                  bbox_to_anchor=(0.5, -0.05),
                  shadow=True,
                  ncol=2)
        if not total:
            ax.set_ylabel('Percent (%)')
        else:
            ax.set_ylabel("Total")
        ax.margins(0, 0)  # Set margins to avoid "whitespace"

    def filter_states(self, max_cell, all_counts, percentage):
        """Fold rare states into an ``"others"`` bucket.

        A state whose count, summed over all frames, is below *percentage*
        percent of *max_cell* is removed from every frame's counts. Each frame
        then gets an ``"others"`` entry holding the summed counts of the
        removed states in that frame (0 when none apply). *all_counts* (frame
        key -> mapping of state -> count) is modified in place and returned.
        """
        all_state = list(
            dict.fromkeys(state for counts in all_counts.values()
                          for state in counts))
        threshold = (percentage / 100) * max_cell

        banned_list = [
            state for state in all_state
            if sum(counts.get(state, 0)
                   for counts in all_counts.values()) < threshold
        ]

        for counts in all_counts.values():
            removed = 0
            for state in banned_list:
                removed += counts.pop(state, 0)
            counts["others"] = removed

        return all_counts
Exemple #20
0
class CellsTab(object):
    """Notebook tab that plots cells from ``output%08d.xml`` snapshots.

    ``self.cells_plot`` is an ``interactive`` frame slider driving
    ``plot_cells``. ``self.pwidgets`` holds editors for three
    ``user_parameters`` entries (see ``_PARAM_NAMES``), which are synced with
    a config XML tree by ``fill_gui`` / ``fill_xml``. The assembled widget is
    ``self.tab``.
    """

    # user_parameters entries edited by this tab; each name is also the
    # attribute holding its FloatText widget.
    _PARAM_NAMES = ('adhesion_strength', 'repulsion_strength',
                    'rel_max_adhesion_dist')

    def __init__(self, user_tab):
        # NOTE(review): ``user_tab`` is accepted for interface compatibility
        # but is not used here.
        self.output_dir = '.'

        const_width = '180px'  # width of the display-option checkboxes

        max_frames = 1
        self.cells_plot = interactive(self.plot_cells,
                                      frame=(0, max_frames),
                                      continuous_update=False)

        # "plot_size" controls the size of the tab height, not the plot
        # (cf. figsize in plot_cells for that).
        plot_size = '500px'
        self.cells_plot.layout.width = plot_size
        self.cells_plot.layout.height = plot_size
        self.use_defaults = True
        self.show_nucleus = 1  # 0->False, 1->True in Checkbox!
        self.show_edge = 1  # 0->False, 1->True in Checkbox!
        self.show_tracks = 0  # 0->False, 1->True in Checkbox!
        self.trackd = {}  # cell IDs -> their tracks: (x,y) pairs
        self.axes_min = -200.0
        self.axes_max = 200.  # TODO: get from input file

        self.max_frames = BoundedIntText(
            min=0,
            max=99999,
            value=max_frames,
            description='Max',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        self.show_nucleus_checkbox = Checkbox(
            description='nucleus',
            value=True,
            disabled=False,
            layout=Layout(width=const_width),
        )
        self.show_nucleus_checkbox.observe(self.show_nucleus_cb)

        self.show_edge_checkbox = Checkbox(
            description='edge',
            value=True,
            disabled=False,
            layout=Layout(width=const_width),
        )
        self.show_edge_checkbox.observe(self.show_edge_cb)

        # NOTE(review): this checkbox is neither observed nor placed in the
        # layout, and its value (True) disagrees with self.show_tracks (0).
        self.show_tracks_checkbox = Checkbox(
            description='tracks',
            value=True,
            disabled=False,
            layout=Layout(width=const_width),
        )

        # The nucleus/edge/tracks checkboxes are currently not shown.
        row1 = Box(children=[
            Label('select slider: drag or left/right arrows'),
            self.max_frames,
        ], layout=Layout(display='flex',
                         flex_flow='row',
                         align_items='stretch',
                         width='70%'))

        #----------
        style = {'description_width': '25%'}
        name_button_layout = {'width': '50%'}
        widget_layout = {'width': '20%'}

        def param_row(name, color, value):
            """Return [disabled name Button, FloatText editor] for one parameter."""
            name_button = Button(description=name,
                                 disabled=True,
                                 layout=name_button_layout)
            name_button.style.button_color = color
            editor = FloatText(value=value,
                               step=0.1,
                               style=style,
                               layout=widget_layout)
            return [name_button, editor]

        row6 = param_row('adhesion_strength', 'tan', 0.5)
        row7 = param_row('repulsion_strength', 'lightgreen', 0.5)
        row8 = param_row('rel_max_adhesion_dist', 'tan', 1.25)
        self.adhesion_strength = row6[1]
        self.repulsion_strength = row7[1]
        self.rel_max_adhesion_dist = row8[1]

        box_layout = Layout(display='flex',
                            flex_flow='row',
                            align_items='stretch',
                            width='100%')
        rows = (row6, row7, row8)
        # self.params is kept for backward compatibility; it wraps the same
        # editor widgets as self.pwidgets (which is the one displayed).
        self.params = VBox([Box(children=r, layout=box_layout) for r in rows])
        self.pwidgets = VBox([Box(children=r, layout=box_layout) for r in rows])

        #----------
        self.plot_stuff = VBox([row1, self.cells_plot])
        self.tab = HBox([self.plot_stuff, self.pwidgets], layout=box_layout)

    def update(self, rdir=''):
        """Optionally switch to ``rdir``, then set the slider max to the last frame.

        Frame numbers are parsed from files named ``output<digits>.xml`` in
        ``self.output_dir``; other files matching ``output*.xml`` are ignored
        instead of crashing the parse. Nothing changes if no frame exists.
        """
        if rdir:
            self.output_dir = rdir

        frames = []
        for path in glob.glob(os.path.join(self.output_dir, 'output*.xml')):
            digits = os.path.basename(path)[len('output'):-len('.xml')]
            if digits.isdecimal():
                frames.append(int(digits))
        if frames:
            # Note! the following will trigger: self.max_frames.observe(self.update_max_frames)
            self.max_frames.value = max(frames)

    def show_nucleus_cb(self, b):
        """Checkbox observer: mirror the nucleus checkbox into show_nucleus and redraw."""
        self.show_nucleus = 1 if self.show_nucleus_checkbox.value else 0
        self.cells_plot.update()

    def show_edge_cb(self, b):
        """Checkbox observer: mirror the edge checkbox into show_edge and redraw."""
        self.show_edge = 1 if self.show_edge_checkbox.value else 0
        self.cells_plot.update()

    # Note! this is called for EACH change to "Max" frames, which happens with every new output file!
    def update_max_frames(self, _b):
        """Observer: keep the frame slider's upper bound equal to the "Max" box."""
        self.cells_plot.children[0].max = self.max_frames.value

    #-----------------------------------------------------
    def circles(self, x, y, s, c='b', vmin=None, vmax=None, **kwargs):
        """
        See https://gist.github.com/syrte/592a062c562cd2a98a83

        Make a scatter plot of circles on the current axes.
        Similar to plt.scatter, but the size of circles are in data scale.
        Parameters
        ----------
        x, y : scalar or array_like, shape (n, )
            Input data
        s : scalar or array_like, shape (n, )
            Radius of circles.
        c : color or sequence of color, optional, default : 'b'
            `c` can be a single color format string, or a sequence of color
            specifications of length `N`, or a sequence of `N` numbers to be
            mapped to colors using the `cmap` and `norm` specified via kwargs.
            Note that `c` should not be a single numeric RGB or RGBA sequence
            because that is indistinguishable from an array of values
            to be colormapped. (If you insist, use `color` instead.)
        vmin, vmax : scalar, optional, default: None
            `vmin` and `vmax` are used in conjunction with `norm` to normalize
            luminance data.  If either are `None`, the min and max of the
            color array is used.
        kwargs : `~matplotlib.collections.Collection` properties
            Eg. alpha, edgecolor(ec), facecolor(fc), linewidth(lw), linestyle(ls),
            norm, cmap, transform, etc.
        Returns
        -------
        collection : `~matplotlib.collections.PatchCollection`
            The collection added to the current axes.
        License
        --------
        This code is under [The BSD 3-Clause License]
        (http://opensource.org/licenses/BSD-3-Clause)
        """
        if np.isscalar(c):
            kwargs.setdefault('color', c)
            c = None

        # Expand the short aliases; an explicit long name wins if both given.
        for short, full in (('fc', 'facecolor'), ('ec', 'edgecolor'),
                            ('ls', 'linestyle'), ('lw', 'linewidth')):
            if short in kwargs:
                kwargs.setdefault(full, kwargs.pop(short))

        zipped = np.broadcast(x, y, s)
        patches = [Circle((x_, y_), s_) for x_, y_, s_ in zipped]
        collection = PatchCollection(patches, **kwargs)
        if c is not None:
            c = np.broadcast_to(c, zipped.shape).ravel()
            collection.set_array(c)
            collection.set_clim(vmin, vmax)

        ax = plt.gca()
        ax.add_collection(collection)
        ax.autoscale_view()
        if c is not None:
            plt.sci(collection)
        return collection

    @staticmethod
    def _cell_radii(volumes):
        """Return sphere radii for ``volumes``: r = (3 V / (4 pi)) ** (1/3)."""
        return np.cbrt(np.asarray(volumes, dtype=float) * 0.75 / np.pi)

    #-------------------------
    def plot_cells(self, frame):
        """Draw the cells of snapshot ``output%08d.xml`` for ``frame``.

        Prints a hint and returns if the file does not exist yet. Also records
        ``frame`` in the module-level ``current_frame``.
        """
        global current_frame
        current_frame = frame
        fname = "output%08d.xml" % frame
        full_fname = os.path.join(self.output_dir, fname)
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")
            return

        mcds = pyMCDS(fname, self.output_dir)
        cells = mcds.data['discrete_cells']

        num_cells = cells['ID'].shape[0]
        xvals = np.asarray(cells['position_x'], dtype=float)
        yvals = np.asarray(cells['position_y'], dtype=float)
        title_str = str(
            mcds.get_time()) + " min (" + str(num_cells) + " agents)"

        # TODO: make figsize a function of plot_size? What about non-square plots?
        self.fig = plt.figure(figsize=(8, 8))

        cell_radii = self._cell_radii(cells['total_volume'])
        self.circles(xvals, yvals, s=cell_radii, fc='none')

        plt.xlim(self.axes_min, self.axes_max)
        plt.ylim(self.axes_min, self.axes_max)
        plt.title(title_str)

    # Populate the GUI widgets with values from the XML
    def fill_gui(self, xml_root):
        """Copy the tab's parameters from ``user_parameters`` into the editors."""
        uep = xml_root.find('.//user_parameters')  # find unique entry point
        for name in self._PARAM_NAMES:
            getattr(self, name).value = float(uep.find('.//' + name).text)

    # Read values from the GUI widgets and write them into the XML tree
    def fill_xml(self, xml_root):
        """Write the editors' values into ``user_parameters`` (in place)."""
        # TODO: verify template .xml file exists!
        # TODO: verify valid type (numeric) and range?
        uep = xml_root.find('.//user_parameters')  # find unique entry point
        for name in self._PARAM_NAMES:
            uep.find('.//' + name).text = str(getattr(self, name).value)
Exemple #21
0
class SubstrateTab(object):

    def __init__(self):
        """Create the substrate/cell plot tab: default state, widgets, observers, layout.

        The interactive plot (``self.i_plot``) drives ``self.plot_substrate``.
        The assembled widget is stored in ``self.tab``.
        """
        self.output_dir = '.'

        self.figsize_width_substrate = 15.0  # allow extra for colormap
        self.figsize_height_substrate = 12.5
        self.figsize_width_svg = 12.0
        self.figsize_height_svg = 12.0

        self.first_time = True
        self.modulo = 1

        self.use_defaults = True

        self.svg_delta_t = 1
        self.substrate_delta_t = 1
        self.svg_frame = 1
        self.substrate_frame = 1

        self.customized_output_freq = False
        self.therapy_activation_time = 1000000
        self.max_svg_frame_pre_therapy = 1000000
        self.max_substrate_frame_pre_therapy = 1000000

        self.svg_xmin = 0

        self.show_nucleus = True
        self.show_edge = True

        # initial value
        self.field_index = 4

        self.skip_cb = False

        # define dummy size of mesh (set in the tool's primary module)
        self.numx = 0
        self.numy = 0

        self.title_str = ''

        const_width = '180px'
        const_width2 = '150px'

        max_frames = 1
        self.i_plot = interactive(self.plot_substrate, frame=(0, max_frames), continuous_update=False)

        # "plot_size" controls the size of the tab height, not the plot (cf. figsize for that)
        # NOTE: the Substrates Plot tab has an extra row of widgets at the top of it (cf. Cell Plots tab)
        svg_plot_size = '900px'
        self.i_plot.layout.width = svg_plot_size
        self.i_plot.layout.height = svg_plot_size

        self.fontsize = 20

        self.max_frames = BoundedIntText(
            min=0, max=99999, value=max_frames,
            description='# frames',
            layout=Layout(width='160px'),
        )
        self.max_frames.observe(self.update_max_frames)

        # NOTE: fields are set manually for now (vs. parsing them out of data/initial.xml).
        # field_dict is the inverse of the Dropdown options (index -> name);
        # field_min_max holds [cmap min, cmap max, range fixed?] per field.
        field_options = {'director signal': 0, 'cargo signal': 1}
        self.field_min_max = {name: [0., 1., False] for name in field_options}
        self.field_dict = {index: name for name, index in field_options.items()}

        self.mcds_field = Dropdown(
            options=field_options,
            value=0,
            layout=Layout(width=const_width)
        )
        self.mcds_field.observe(self.mcds_field_changed_cb)

        self.field_cmap = Dropdown(
            options=['viridis', 'jet', 'YlOrRd'],
            value='YlOrRd',
            layout=Layout(width=const_width)
        )
        self.field_cmap.observe(self.mcds_field_cb)

        self.cmap_fixed_toggle = Checkbox(
            description='Fix',
            disabled=False,
        )
        self.cmap_fixed_toggle.observe(self.mcds_field_cb)

        self.cmap_min = FloatText(
            description='Min',
            value=0,
            step=0.1,
            disabled=True,
            layout=Layout(width=const_width2),
        )
        self.cmap_min.observe(self.mcds_field_cb)

        self.cmap_max = FloatText(
            description='Max',
            value=38,
            step=0.1,
            disabled=True,
            layout=Layout(width=const_width2),
        )
        self.cmap_max.observe(self.mcds_field_cb)

        def cmap_fixed_toggle_cb(b):
            # Min/Max are editable (and used) only while the range is fixed.
            field_name = self.field_dict[self.mcds_field.value]
            fixed = self.cmap_fixed_toggle.value
            self.cmap_min.disabled = not fixed
            self.cmap_max.disabled = not fixed
            if fixed:
                self.field_min_max[field_name][0] = self.cmap_min.value
                self.field_min_max[field_name][1] = self.cmap_max.value
                self.field_min_max[field_name][2] = True
            else:
                self.field_min_max[field_name][2] = False
            self.i_plot.update()

        self.cmap_fixed_toggle.observe(cmap_fixed_toggle_cb)

        #---------------------
        self.cell_nucleus_toggle = Checkbox(
            description='nuclei',
            disabled=False,
            value=self.show_nucleus,
        )

        def cell_nucleus_toggle_cb(b):
            self.show_nucleus = bool(self.cell_nucleus_toggle.value)
            self.i_plot.update()

        self.cell_nucleus_toggle.observe(cell_nucleus_toggle_cb)

        #----
        self.cell_edges_toggle = Checkbox(
            description='edges',
            disabled=False,
            value=self.show_edge,
        )

        def cell_edges_toggle_cb(b):
            self.show_edge = bool(self.cell_edges_toggle.value)
            self.i_plot.update()

        self.cell_edges_toggle.observe(cell_edges_toggle_cb)

        self.cells_toggle = Checkbox(
            description='Cells',
            disabled=False,
            value=True,
        )

        def cells_toggle_cb(b):
            self.i_plot.update()
            # The nucleus/edge options only make sense while cells are shown.
            cells_off = not self.cells_toggle.value
            self.cell_edges_toggle.disabled = cells_off
            self.cell_nucleus_toggle.disabled = cells_off

        self.cells_toggle.observe(cells_toggle_cb)

        #---------------------
        self.substrates_toggle = Checkbox(
            description='Substrates',
            disabled=False,
            value=True,
        )

        def substrates_toggle_cb(b):
            substrates_off = not self.substrates_toggle.value
            self.cmap_fixed_toggle.disabled = substrates_off
            # Consistent with cmap_fixed_toggle_cb: Min/Max stay disabled
            # unless the colormap range is fixed, even when substrates return.
            range_locked = substrates_off or not self.cmap_fixed_toggle.value
            self.cmap_min.disabled = range_locked
            self.cmap_max.disabled = range_locked
            self.mcds_field.disabled = substrates_off
            self.field_cmap.disabled = substrates_off

        self.substrates_toggle.observe(substrates_toggle_cb)

        # NOTE(review): created and observed but not placed in the layout.
        self.grid_toggle = Checkbox(
            description='grid',
            disabled=False,
            value=True,
        )

        def grid_toggle_cb(b):
            self.i_plot.update()

        self.grid_toggle.observe(grid_toggle_cb)

        def bordered_row(children):
            """Return a half-width, bordered flex row holding ``children``."""
            return Box(children, layout=Layout(border='1px solid black',
                                               width='50%',
                                               height='',
                                               align_items='stretch',
                                               flex_direction='row',
                                               display='flex'))

        row1a = bordered_row([self.max_frames, self.mcds_field, self.field_cmap])
        row1b = bordered_row([self.cells_toggle, self.cell_nucleus_toggle, self.cell_edges_toggle])
        row1 = HBox([row1a, Label('.....'), row1b])

        row2a = bordered_row([self.cmap_fixed_toggle, self.cmap_min, self.cmap_max])
        row2b = bordered_row([self.substrates_toggle, ])
        row2 = HBox([row2a, Label('.....'), row2b])

        if (hublib_flag):
            self.download_button = Download('mcds.zip', style='warning', icon='cloud-download',
                                            tooltip='Download data', cb=self.download_cb)

            self.download_svg_button = Download('svg.zip', style='warning', icon='cloud-download',
                                                tooltip='You need to allow pop-ups in your browser', cb=self.download_svg_cb)
            download_row = HBox([self.download_button.w, self.download_svg_button.w, Label("Download all cell plots (browser must allow pop-ups).")])

            controls_box = VBox([row1, row2])
            self.tab = VBox([controls_box, self.i_plot, download_row])
        else:
            self.tab = VBox([row1, row2, self.i_plot])

    #---------------------------------------------------
    def update_dropdown_fields(self, data_dir):
        """Populate the substrate-field dropdown from ``<data_dir>/initial.xml``.

        Every ``variable`` element under the first ``variables`` element becomes a
        dropdown entry. ``self.field_dict`` (index -> name), ``self.field_min_max``
        (name -> [cmap_min, cmap_max, fixed-range flag], reset to [0, 1, False])
        and ``self.mcds_field.options`` are rebuilt from scratch.

        Parameters
        ----------
        data_dir : str
            Output directory to read from; stored in ``self.output_dir``.

        If ``initial.xml`` cannot be opened or parsed, a message is printed and the
        existing field tables are left untouched.
        """
        self.output_dir = data_dir
        # Built outside the try so the error message can always reference it.
        fname = os.path.join(self.output_dir, "initial.xml")
        try:
            xml_root = ET.parse(fname).getroot()
        except (OSError, ET.ParseError):
            print("Cannot open ",fname," to read info, e.g., names of substrate fields.")
            return

        self.field_min_max = {}
        self.field_dict = {}
        dropdown_options = {}
        uep = xml_root.find('.//variables')
        # Compare with None: an Element's truth value reflects its child count,
        # not whether it was found.
        if uep is not None:
            for field_idx, elm in enumerate(uep.findall('variable')):
                field_name = elm.attrib['name']
                # [colormap min, colormap max, fixed-range toggle]
                self.field_min_max[field_name] = [0, 1, False]
                self.field_dict[field_idx] = field_name
                dropdown_options[field_name] = field_idx

        self.mcds_field.value = 0
        self.mcds_field.options = dropdown_options

    # def update_max_frames_expected(self, value):  # called when beginning an interactive Run
    #     self.max_frames.value = value  # assumes naming scheme: "snapshot%08d.svg"
    #     self.mcds_plot.children[0].max = self.max_frames.value

#------------------------------------------------------------------------------
    def update_params(self, config_tab, user_params_tab):
        """Copy domain geometry and output settings from the Config tab.

        Sets the domain bounds and extents, the mesh size (``numx``/``numy``),
        the figure sizes for substrate and cell plots (scaled by the domain
        aspect ratio), the SVG/substrate output flags and intervals, and
        ``modulo`` = int(substrate interval / SVG interval).

        Parameters
        ----------
        config_tab : object
            Exposes ``xmin``/``xmax``/``ymin``/``ymax``/``xdelta``/``ydelta``,
            ``toggle_svg``/``toggle_mcds`` and ``svg_interval``/``mcds_interval``
            widgets, each read through ``.value``.
        user_params_tab : object
            Currently unused.
        """
        cfg = config_tab

        self.xmin = cfg.xmin.value
        self.xmax = cfg.xmax.value
        self.x_range = self.xmax - self.xmin
        self.svg_xrange = self.x_range
        self.ymin = cfg.ymin.value
        self.ymax = cfg.ymax.value
        self.y_range = self.ymax - self.ymin

        # mesh points per axis (used to reshape the substrate data)
        self.numx = math.ceil(self.x_range / cfg.xdelta.value)
        self.numy = math.ceil(self.y_range / cfg.ydelta.value)

        # Shrink the shorter side by the aspect ratio; the substrate figure
        # gets extra width for the colormap.
        if self.x_range > self.y_range:
            ratio = self.y_range / self.x_range
            sub_w, sub_h = 15.0, 12.5 * ratio
            svg_w, svg_h = 12.0, 12.0 * ratio
        else:  # x <= y
            ratio = self.x_range / self.y_range
            sub_w, sub_h = 15.0 * ratio, 12.5
            svg_w, svg_h = 12.0 * ratio, 12.0
        self.figsize_width_substrate = sub_w
        self.figsize_height_substrate = sub_h
        self.figsize_width_svg = svg_w
        self.figsize_height_svg = svg_h

        self.svg_flag = cfg.toggle_svg.value
        self.substrates_flag = cfg.toggle_mcds.value
        self.svg_delta_t = cfg.svg_interval.value
        self.substrate_delta_t = cfg.mcds_interval.value
        self.modulo = int(self.substrate_delta_t / self.svg_delta_t)

        if not self.customized_output_freq:
            return
        # self.therapy_activation_time must already be set on the instance.
        activation = self.therapy_activation_time
        self.max_svg_frame_pre_therapy = int(activation / self.svg_delta_t)
        self.max_substrate_frame_pre_therapy = int(activation / self.substrate_delta_t)

#------------------------------------------------------------------------------
#    def update(self, rdir):
#   Called from driver module (e.g., pc4*.py) (among other places?)
    def update(self, rdir=''):
        """Refresh state from the output directory (called from the driver module).

        On the first call only, reads the SVG and full-data save intervals from
        ``config.xml`` (when present) and derives ``self.modulo``. On every call,
        sets ``self.max_frames.value`` to the index of the newest ``snap*.svg``
        file, falling back to the newest ``output*.xml`` when there are no SVGs.

        Parameters
        ----------
        rdir : str, optional
            New output directory; when empty, ``self.output_dir`` is kept.
        """
        if rdir:
            self.output_dir = rdir

        if self.first_time:
            self.first_time = False
            config_path = Path(os.path.join(self.output_dir, 'config.xml'))
            if config_path.is_file():
                # this file cannot be overwritten; part of tool distro
                settings = ET.parse(str(config_path)).getroot()
                self.svg_delta_t = float(settings.find(".//SVG//interval").text)
                self.substrate_delta_t = float(settings.find(".//full_data//interval").text)
                self.modulo = int(self.substrate_delta_t / self.svg_delta_t)

        # The frame index is the 8 digits before the 4-char extension,
        # e.g. "snapshot%08d.svg" or "output%08d.xml".
        for pattern in ('snap*.svg', 'output*.xml'):
            candidates = sorted(glob.glob(os.path.join(self.output_dir, pattern)))
            if candidates:
                self.max_frames.value = int(candidates[-1][-12:-4])
                break

    def download_svg_cb(self):
        """Archive every ``*.svg`` in ``self.output_dir`` into ``svg.zip``.

        The archive is written to the current working directory; entries are
        stored under their base names only.
        """
        svg_pattern = os.path.join(self.output_dir, '*.svg')
        with zipfile.ZipFile('svg.zip', 'w') as archive:
            for svg_path in glob.glob(svg_pattern):
                # arcname drops the directory prefix inside the archive
                archive.write(svg_path, arcname=os.path.basename(svg_path))

    def download_cb(self):
        """Archive every ``*.xml`` and then every ``*.mat`` in ``self.output_dir``.

        The result is written to ``mcds.zip`` in the current working directory;
        entries are stored under their base names only.
        """
        with zipfile.ZipFile('mcds.zip', 'w') as archive:
            for suffix_glob in ('*.xml', '*.mat'):
                for data_path in glob.glob(os.path.join(self.output_dir, suffix_glob)):
                    # arcname drops the directory prefix inside the archive
                    archive.write(data_path, arcname=os.path.basename(data_path))

    def update_max_frames(self,_b):
        """Set the upper bound of the first ``self.i_plot`` control to ``max_frames``.

        The first control is presumably the frame slider. ``_b`` is the widget
        change notification and is unused.
        """
        frame_slider = self.i_plot.children[0]
        frame_slider.max = self.max_frames.value

    # called if user selected different substrate in dropdown
    def mcds_field_changed_cb(self, b):
        """Handle a new substrate selection in the ``mcds_field`` dropdown.

        Loads the stored colormap settings of the selected field into the
        ``cmap_min``/``cmap_max``/``cmap_fixed_toggle`` widgets, then redraws.

        Parameters
        ----------
        b : object
            Widget change notification (unused).
        """
        if self.mcds_field.value is None:
            return
        # dropdown index -> row of the microenvironment matrix (offset of 4)
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
        settings = self.field_min_max[field_name]

        # Setting these widgets triggers mcds_field_cb(); skip_cb tells it to
        # ignore those notifications. Reset in ``finally`` so a failing widget
        # assignment cannot leave mcds_field_cb permanently disabled.
        self.skip_cb = True
        try:
            self.cmap_min.value = settings[0]
            self.cmap_max.value = settings[1]
            self.cmap_fixed_toggle.value = bool(settings[2])
        finally:
            self.skip_cb = False

        self.i_plot.update()

    # called if user provided different min/max values for colormap, or a different colormap
    def mcds_field_cb(self, b):
        """Store the current colormap settings for the selected field and redraw.

        Called when the user changes the colormap min/max values or the colormap.
        Notifications fired while ``mcds_field_changed_cb`` is loading a field's
        settings are ignored via ``self.skip_cb``.

        Parameters
        ----------
        b : object
            Widget change notification (unused).
        """
        if self.skip_cb:
            return
        # The dropdown has no selection yet (e.g. no fields loaded): nothing to store.
        if self.mcds_field.value is None:
            return

        # dropdown index -> row of the microenvironment matrix (offset of 4)
        self.field_index = self.mcds_field.value + 4

        field_name = self.field_dict[self.mcds_field.value]
        settings = self.field_min_max[field_name]  # [cmap min, cmap max, fixed-range flag]
        settings[0] = self.cmap_min.value
        settings[1] = self.cmap_max.value
        settings[2] = self.cmap_fixed_toggle.value

        self.i_plot.update()


    #---------------------------------------------------------------------------
    def circles(self, x, y, s, c='b', vmin=None, vmax=None, **kwargs):
        """
        See https://gist.github.com/syrte/592a062c562cd2a98a83 

        Make a scatter plot of circles on the current axes.
        Similar to plt.scatter, but the size of circles are in data scale.

        Parameters
        ----------
        x, y : scalar or array_like, shape (n, )
            Input data
        s : scalar or array_like, shape (n, ) 
            Radius of circles.
        c : color or sequence of color, optional, default : 'b'
            `c` can be a single color format string, or a sequence of color
            specifications of length `N`, or a sequence of `N` numbers to be
            mapped to colors using the `cmap` and `norm` specified via kwargs.
            Note that `c` should not be a single numeric RGB or RGBA sequence 
            because that is indistinguishable from an array of values
            to be colormapped. (If you insist, use `color` instead.)  
            `c` can be a 2-D array in which the rows are RGB or RGBA, however. 
        vmin, vmax : scalar, optional, default: None
            `vmin` and `vmax` are used in conjunction with `norm` to normalize
            luminance data.  If either are `None`, the min and max of the
            color array is used.
        kwargs : `~matplotlib.collections.Collection` properties
            Eg. alpha, edgecolor(ec), facecolor(fc), linewidth(lw), linestyle(ls), 
            norm, cmap, transform, etc.

        Returns
        -------
        collection : `~matplotlib.collections.PatchCollection`
            The collection of circles that was added to the current axes.

        Examples
        --------
        a = np.arange(11)
        circles(a, a, s=a*0.2, c=a, alpha=0.5, ec='none')
        plt.colorbar()

        License
        --------
        This code is under [The BSD 3-Clause License]
        (http://opensource.org/licenses/BSD-3-Clause)
        """

        # A scalar `c` is a single color for every circle, not data to colormap.
        if np.isscalar(c):
            kwargs.setdefault('color', c)
            c = None

        # Expand the matplotlib short aliases (an explicit long name wins).
        if 'fc' in kwargs:
            kwargs.setdefault('facecolor', kwargs.pop('fc'))
        if 'ec' in kwargs:
            kwargs.setdefault('edgecolor', kwargs.pop('ec'))
        if 'ls' in kwargs:
            kwargs.setdefault('linestyle', kwargs.pop('ls'))
        if 'lw' in kwargs:
            kwargs.setdefault('linewidth', kwargs.pop('lw'))
        # You can set `facecolor` with an array for each patch,
        # while you can only set `facecolors` with a value for all.

        zipped = np.broadcast(x, y, s)
        patches = [Circle((x_, y_), s_)
                for x_, y_, s_ in zipped]
        collection = PatchCollection(patches, **kwargs)
        if c is not None:
            c = np.broadcast_to(c, zipped.shape).ravel()
            collection.set_array(c)
            collection.set_clim(vmin, vmax)

        ax = plt.gca()
        ax.add_collection(collection)
        ax.autoscale_view()
        if c is not None:
            # make the collection the current image so plt.colorbar() picks it up
            plt.sci(collection)
        return collection

    #------------------------------------------------------------
    # def plot_svg(self, frame, rdel=''):
    def plot_svg(self, frame):
        """Draw the cells stored in ``snapshot%08d.svg`` for *frame*.

        Parses the SVG snapshot in ``self.output_dir`` and appends the snapshot
        time and the number of cells to ``self.title_str``. It then sets the axis
        limits to the domain bounds and draws one circle per cell on the current
        matplotlib axes (plus the nucleus circle when ``self.show_nucleus``).

        Parameters
        ----------
        frame : int
            SVG snapshot index; also stored in the module-level ``current_frame``.
        """
        global current_frame
        current_frame = frame
        fname = "snapshot%08d.svg" % frame
        full_fname = os.path.join(self.output_dir, fname)
        if not os.path.isfile(full_fname):
            print("Once output files are generated, click the slider.")   
            return

        root = ET.parse(full_fname).getroot()

        # Scan the top-level elements. Record the width (when using defaults),
        # append the "Current time" text to the title, and stop at the first
        # element that carries an id, which is taken to be the tissue group.
        tissue_parent = None
        for child in root:
            if self.use_defaults and ('width' in child.attrib):
                self.axes_max = float(child.attrib['width'])
            if child.text and "Current time" in child.text:
                svals = child.text.split()
                # [:-3] removes the ".00" on minutes
                self.title_str += "   cells: " + svals[2] + "d, " + svals[4] + "h, " + svals[7][:-3] + "m"
            if 'id' in child.attrib:
                tissue_parent = child
                break

        cells_parent = None
        if tissue_parent is not None:
            for child in tissue_parent:
                if child.attrib.get('id') == 'cells':
                    cells_parent = child
                    break
        if cells_parent is None:
            # Without a cells group there is nothing to draw.
            print("plot_svg: no 'cells' group found in", full_fname)
            return

        too_large_val = 10000.  # reject bogus x,y locations (rwh TODO: use max of domain?)
        xlist = []
        ylist = []
        rlist = []
        rgb_list = []
        num_cells = 0
        for child in cells_parent:
            # each cell group holds its outer circle, optionally followed by the nucleus
            for circle in child:
                # map SVG coords into comp domain
                # NOTE(review): numerically this is xval + self.xmin; kept as written.
                xval = float(circle.attrib['cx'])
                xval = xval/self.x_range * self.x_range + self.xmin

                s = circle.attrib['fill']
                if s.startswith("rgb"):  # an rgb string, e.g. "rgb(175,175,80)"
                    rgb = [int(v) / 255. for v in s[4:-1].split(",")]
                else:  # otherwise, must be a color name
                    rgb = list(mplc.to_rgb(mplc.cnames[s]))

                if np.fabs(xval) > too_large_val:
                    print("bogus xval=", xval)
                    break
                yval = float(circle.attrib['cy'])
                yval = yval/self.y_range * self.y_range + self.ymin
                if np.fabs(yval) > too_large_val:
                    print("bogus yval=", yval)
                    break

                rval = float(circle.attrib['r'])
                xlist.append(xval)
                ylist.append(yval)
                rlist.append(rval)
                rgb_list.append(rgb)

                # only the first (outer) circle unless nuclei are shown
                if not self.show_nucleus:
                    break

            num_cells += 1

        xvals = np.array(xlist)
        yvals = np.array(ylist)
        rvals = np.array(rlist)
        rgbs = np.array(rgb_list)

        self.title_str += " (" + str(num_cells) + " agents)"
        plt.title(self.title_str)

        plt.xlim(self.xmin, self.xmax)
        plt.ylim(self.ymin, self.ymax)

        # rwh - temp fix - the ValueError only occurs when "edges" is toggled on
        if self.show_edge:
            try:
                self.circles(xvals, yvals, s=rvals, color=rgbs, edgecolor='black', linewidth=0.5)
            except ValueError:
                pass
        else:
            self.circles(xvals, yvals, s=rvals, color=rgbs)

    #---------------------------------------------------------------------------
    # assume "frame" is cell frame #, unless Cells is togggled off, then it's the substrate frame #
    # def plot_substrate(self, frame, grid):
    def plot_substrate(self, frame):
        """Draw the selected substrate field and/or the cells for slider value *frame*.

        ``frame`` is the cell (SVG) frame number.

        When the substrates toggle is on, ``frame`` is mapped to
        ``self.substrate_frame``. The matching ``output%08d.xml`` and
        ``output%08d_microenvironment0.mat`` files are loaded, and row
        ``self.field_index`` of the microenvironment matrix is drawn as a filled
        contour plot with a colorbar.

        When the cells toggle is on, the cells are drawn via ``plot_svg``, on top
        of the substrate if that was drawn.

        Parameters
        ----------
        frame : int
            Slider value (SVG snapshot index).
        """
        self.title_str = ''

        if self.substrates_toggle.value:
            self.fig = plt.figure(figsize=(self.figsize_width_substrate, self.figsize_height_substrate))

            # rwh - map the SVG frame to a substrate frame. With customized output
            # frequencies (e.g. pc4cancerbots), each SVG frame past the pre-therapy
            # range maps to one further substrate frame.
            if self.customized_output_freq and (frame > self.max_svg_frame_pre_therapy):
                self.substrate_frame = self.max_substrate_frame_pre_therapy + (frame - self.max_svg_frame_pre_therapy)
            else:
                # NOTE(review): raises ZeroDivisionError if self.modulo is 0
                # (substrate interval shorter than the SVG interval) - confirm intent.
                self.substrate_frame = int(frame / self.modulo)

            fname = "output%08d_microenvironment0.mat" % self.substrate_frame
            xml_fname = "output%08d.xml" % self.substrate_frame
            full_fname = os.path.join(self.output_dir, fname)
            full_xml_fname = os.path.join(self.output_dir, xml_fname)
            if not os.path.isfile(full_fname):
                print("Once output files are generated, click the slider.")
                return

            xml_root = ET.parse(full_xml_fname).getroot()
            mins = int(float(xml_root.find(".//current_time").text))  # TODO: check units = mins
            self.substrate_mins = mins
            hrs = int(mins/60)
            days = int(hrs/24)
            self.title_str = 'substrate: %dd, %dh, %dm' % (int(days),(hrs%24), mins - (hrs*60))

            info_dict = {}
            scipy.io.loadmat(full_fname, info_dict)
            M = info_dict['multiscale_microenvironment']
            # Rows 0 and 1 of M hold the mesh x and y coordinates; the selected
            # field's row is taken here so a bad index still raises immediately.
            field = M[self.field_index, :]

            mesh_ok = True
            try:
                xgrid = M[0, :].reshape(self.numy, self.numx)
                ygrid = M[1, :].reshape(self.numy, self.numx)
            except Exception:  # best effort: report it and skip the contour plot
                print("substrates.py: mismatched mesh size for reshape: numx,numy=",self.numx, self.numy)
                mesh_ok = False

            num_contours = 15
            levels = MaxNLocator(nbins=num_contours).tick_values(self.cmap_min.value, self.cmap_max.value)
            contour_ok = False
            if mesh_ok:
                fixed_range = self.cmap_fixed_toggle.value
                try:
                    zgrid = field.reshape(self.numy, self.numx)
                    # contourf has no `fontsize` parameter; the font size is
                    # applied to the title and colorbar below instead.
                    if fixed_range:
                        substrate_plot = plt.contourf(xgrid, ygrid, zgrid, levels=levels, extend='both', cmap=self.field_cmap.value)
                    else:
                        substrate_plot = plt.contourf(xgrid, ygrid, zgrid, num_contours, cmap=self.field_cmap.value)
                    contour_ok = True
                except Exception:
                    # deliberate best effort: a failed contour just leaves it out
                    contour_ok = False

            if contour_ok:
                plt.title(self.title_str, fontsize=self.fontsize)
                cbar = self.fig.colorbar(substrate_plot)
                cbar.ax.tick_params(labelsize=self.fontsize)

            plt.xlim(self.xmin, self.xmax)
            plt.ylim(self.ymin, self.ymax)

        # Now plot the cells (possibly on top of the substrate)
        if self.cells_toggle.value:
            if not self.substrates_toggle.value:
                self.fig = plt.figure(figsize=(self.figsize_width_svg, self.figsize_height_svg))
            self.svg_frame = frame
            self.plot_svg(self.svg_frame)
Exemple #22
0
class PaginationWidget(HBox):
    """A pagination widget that enables setting page number and the number of rows per page.

    The widget contains a page selector, a label with the total number of
    pages, and a rows-per-page selector. It is hidden when all rows fit in a
    single page. The current selection is exposed through the ``page`` and
    ``limit`` traits.
    """

    # number of rows in a page by default
    DEFAULT_LIMIT = 50
    # max number of rows to avoid performance problems
    MAX_LIMIT = 100

    # current page number (starts at 1)
    page = Int()
    # number of rows per page.
    limit = Int()

    def __init__(self, nb_rows, limit=DEFAULT_LIMIT, *args, **kwargs):
        """
        Parameters
        ----------
        nb_rows: int
            total number of rows in the result set. A falsy value (0 or
            None) is treated as a single row.
        limit: int
            number of rows to display in a page. Values outside
            ``[1, MAX_LIMIT]`` (or None) fall back to ``DEFAULT_LIMIT``.
        *args, **kwargs:
            forwarded to the ``HBox`` constructor.

        """
        super().__init__(*args, **kwargs)

        self.__nb_rows = nb_rows if nb_rows else 1
        self.page = 1
        # MAX_LIMIT itself is a valid value: the rows widget below allows it.
        self.limit = (limit if limit and 0 < limit <= PaginationWidget.MAX_LIMIT
                      else PaginationWidget.DEFAULT_LIMIT)

        self.layout.width = "400px"

        # Compare the *sanitized* values: the raw arguments may be None or
        # out of range, and the page count below is computed from these.
        if self.__nb_rows <= self.limit:
            # Everything fits in one page: no controls are needed.
            self.layout.visibility = "hidden"
        else:
            self.layout.visibility = "visible"

            nb_pages = self._get_nb_pages(self.limit)

            # widget to set page number
            self.__page_widget = BoundedIntText(
                value=self.page,
                min=1,
                max=nb_pages,
                step=1,
                continuous_update=True,
                description="page",
                description_tooltip="Current page",
                disabled=False,
                style={"description_width": "30px"},
                layout=Layout(width="90px", max_width="90px"),
            )

            # widget to display total number of pages.
            self.__label_slash = Label(value=f"/ {nb_pages}",
                                       layout=Layout(width="60px"))

            # widget to set limit
            self.__limit_widget = BoundedIntText(
                value=self.limit,
                min=1,
                max=PaginationWidget.MAX_LIMIT,
                step=1,
                continuous_update=True,
                description="rows",
                description_tooltip=
                f"Number of rows per page. Max. possible: {PaginationWidget.MAX_LIMIT}",
                disabled=False,
                style={"description_width": "30px"},
                layout=Layout(width="90px", max_width="90px"),
            )

            self.__page_widget.observe(self._page_widget_changed,
                                       names="value")
            self.__limit_widget.observe(self._limit_widget_changed,
                                        names="value")

            self.children = [
                self.__page_widget,
                self.__label_slash,
                self.__limit_widget,
            ]

    def _get_nb_pages(self, limit):
        """Return the number of pages needed to show all rows with ``limit`` rows per page."""
        return ceil(self.__nb_rows / limit)

    def _page_widget_changed(self, change):
        """Mirror the page selector's new value into the ``page`` trait."""
        self.page = change["new"]

    def _limit_widget_changed(self, change):
        """Apply a new rows-per-page value and go back to the first page."""
        new_limit = change["new"]
        # update limit
        self.limit = new_limit
        self.page = 1

        nb_pages = self._get_nb_pages(new_limit)

        # update page widget: set the new upper bound, then reset to page 1
        # (this also fires _page_widget_changed when the value changes).
        self.__page_widget.max = nb_pages
        self.__page_widget.value = 1

        # update label slash widget
        self.__label_slash.value = f"/ {nb_pages}"