示例#1
0
class MultiClassAnnotator:
    def __init__(self,
                 dataset_voc,
                 output_path_statistic,
                 name,
                 metrics,  # this is an array of the following ['occ', 'truncated', 'side', 'part']
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):
        """Build the widgets used to browse and annotate the objects of a dataset.

        Args:
            dataset_voc: dataset wrapper. ``load()`` is called on it when its
                ``annotations_gt`` attribute is still ``None``.
            output_path_statistic: directory of the ``<name>.json`` output file;
                ``None`` falls back to ``dataset_voc.dataset_root_param``.
            name: base name of the output file; also added as a CSS class to
                the image output widget.
            metrics: iterable of metric names. ``'truncated'`` and ``'occ'``
                become radio buttons, ``'parts'`` is skipped here (it is
                rendered per object later) and every other name becomes a
                group of checkboxes.
            show_name: stored and used by the default image display function
                to decide whether the file name and class are printed.
            show_axis: stored; the default display hides the axes when False.
            fig_size: figure size passed to matplotlib by the default display.
            buttons_vertical: stored as-is (not read in this constructor).
            image_display_function: callable ``(image_record, obj_num)`` used
                to draw the current object; defaults to the built-in display.
            classes_to_annotate: classes whose objects are listed; ``None``
                means every class in ``dataset_voc.objnames_all``.
        """

        if dataset_voc.annotations_gt is None:  # in case that dataset_voc has not been called
            dataset_voc.load()

        self.dataset_voc = dataset_voc
        self.metrics = metrics
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        if output_path_statistic is None:
            output_path_statistic = self.dataset_voc.dataset_root_param

        # Fix: the attribute used to be assigned only in the ``None`` case, so
        # an explicit class list crashed below with AttributeError.
        if classes_to_annotate is None:  # annotate every class
            self.classes_to_annotate = self.dataset_voc.objnames_all
        else:  # annotate only the requested classes
            self.classes_to_annotate = classes_to_annotate

        self.output_path = output_path_statistic
        self.file_path = os.path.join(self.output_path, self.name + ".json")
        self.mapping, self.dataset = self.__create_results_dict(self.file_path, metrics)

        self.objects = dataset_voc.get_objects_index(self.classes_to_annotate)
        self.current_pos = 0

        self.max_pos = len(self.objects) - 1

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # create buttons
        self.previous_button = self.__create_button("Previous", (self.current_pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", (self.current_pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button, self.save_button]

        # 1-based index box plus the "/ total" label shown next to it
        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        # the screenshot JavaScript locates the output area through this class
        self.out.add_class(name)

        metrics_labels = self.dataset_voc.read_label_metrics_name_from_file()
        metrics_labels['truncated'] = {'0': 'False', '1': 'True'}
        self.metrics_labels = metrics_labels

        self.checkboxes = {}
        self.radiobuttons = {}

        output_layout = []
        for m_n in self.metrics:
            if 'parts' == m_n:  # Special case
                continue

            if m_n in ['truncated', 'occ']:  # radiobutton
                self.radiobuttons[m_n] = RadioButtons(options=list(metrics_labels[m_n].values()),
                                                      disabled=False,
                                                      indent=False)
            else:  # checkbox
                self.checkboxes[m_n] = [Checkbox(False, description='{}'.format(label),
                                                 indent=False) for label in metrics_labels[m_n].values()]

        self.check_radio_boxes_layout = {}
        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        # create an output for the future dynamic SIDES_PARTS attributes
        self.dynamic_output_for_parts = Output()
        html_title = HTML(value="<b>" + "Parts" + "</b>")
        output_layout.append(VBox([html_title, self.dynamic_output_for_parts]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox(output_layout),
                                 self.out])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        mapping = {}
        # mapping["categories_id"] = {}
        # mapping["categories_name"] = {}
        # mapping["objects"] = {}
        #
        # if not os.path.exists(file_path):
        #     dataset = {}
        #     dataset['categories'] = []
        #     dataset["images"] = []
        #     for index, c in enumerate(cc):
        #         category = {}
        #         category["supercategory"] = c
        #         category["name"] = c
        #         category["id"] = index
        #         dataset["categories"].append(category)
        # else:
        #     with open(file_path, 'r') as classification_file:
        #         dataset = json.load(classification_file)
        #     for index, img in enumerate(dataset['images']):
        #         mapping['images'][img["path"]] = index
        #
        # for index, c in enumerate(dataset['categories']):
        #     mapping['categories_id'][c["id"]] = c["name"]
        #     mapping['categories_name'][c["name"]] = c["id"]
        # index_categories = len(dataset['categories']) - 1
        #
        # for c in cc:
        #     if not c in mapping['categories_name'].keys():
        #         mapping['categories_id'][index_categories] = c
        #         mapping['categories_name'][c] = index_categories
        #         category = {}
        #         category["supercategory"] = c
        #         category["name"] = c
        #         category["id"] = index_categories
        #         dataset["categories"].append(category)
        #         index_categories += 1

        return {}, {}

    def __checkbox_changed(self, b):
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value

        current_index = self.mapping["images"][self.objects[self.current_pos]]
        class_index = self.mapping["categories_name"][class_name]
        if not class_index in self.dataset["images"][current_index]["categories"] and value:
            self.dataset["images"][current_index]["categories"].append(class_index)
        if class_index in self.dataset["images"][current_index]["categories"] and not value:
            self.dataset["images"][current_index]["categories"].remove(class_index)

    def __create_button(self, description, disabled, function):
        """Create a ``Button`` with the given caption, state and click callback.

        Args:
            description: text shown on the button.
            disabled: initial value of the button's ``disabled`` attribute.
            function: callback registered through ``on_click``.

        Returns:
            The configured button.
        """
        new_button = Button(description=description)
        new_button.on_click(function)
        new_button.disabled = disabled
        return new_button

    def __show_image(self, image_record, obj_num):
        """Default display: draw the image with the bounding box of one object.

        Args:
            image_record: record providing ``'filename'`` and an ``'objects'``
                sequence whose entries carry ``'bbox'`` and ``'class'``.
            obj_num: index of the object inside ``image_record['objects']``.

        The image is read from ``<self.output_path>/JPEGImages/<filename>``.
        NOTE(review): that is the statistics output path, not necessarily the
        dataset root -- confirm images live there when a custom output path
        is supplied.
        """
        # locate and open the picture
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        img = Image.open(path_img)
        if self.show_name:
            current_object = self.objects[self.current_pos]
            caption = '. Class: {} [class_id={}]'.format(current_object['class_name'],
                                                         current_object['class_id'])
            print(os.path.basename(path_img) + caption)
        plt.figure(figsize=self.fig_size)

        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)

        # one colour per known class, spread over the rainbow colormap
        ax = plt.gca()
        class_colors = cm.rainbow(np.linspace(0, 1, len(self.dataset_voc.objnames_all)))

        # the four values are used as two opposite corners of the rectangle
        [left, top, right, bottom] = image_record['objects'][obj_num]['bbox']
        corners = [[left, top], [left, bottom], [right, bottom], [right, top]]
        np_poly = np.array(corners).reshape((4, 2))

        object_class_name = image_record['objects'][obj_num]['class']
        c = class_colors[self.dataset_voc.objnames_all.index(object_class_name)]
        red, green, blue = c[0], c[1], c[2]

        # transparent face, opaque edge
        outline = Polygon(np_poly, linestyle='-', facecolor=(red, green, blue, 0.0),
                          edgecolor=(red, green, blue, 1.0), linewidth=2)
        ax.add_patch(outline)

        # class name at the first corner of the box
        ax.text(x=left, y=top, s=object_class_name, color='white', fontsize=9,
                horizontalalignment='left', verticalalignment='top',
                bbox=dict(facecolor=(red, green, blue, 0.5)))
        plt.show()

    def save_state(self):
        with open(self.file_path, 'w') as output_file:
            json.dump(self.dataset, output_file, indent=4)

    def __save_function(self, image_path):
        """Ask the browser to download a PNG screenshot of the output area.

        The download is named after the part of the file name in
        ``image_path`` that precedes its first dot. The JavaScript snippet
        selects the output element through the CSS class ``self.name`` and
        renders it with the ``html2canvas`` module.
        """
        file_name = os.path.basename(image_path)
        img_name = file_name.split('.')[0]
        template = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        # fill the placeholders: CSS class first, then the download name
        j_code = template.replace('$it_name$', self.name).replace('$img_name$', img_name)
        # run the script inside a throw-away output widget
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        """Refresh every widget so it reflects the object at ``self.current_pos``.

        Steps, in order: update the enabled state of the navigation buttons,
        load the stored 'occ', 'truncated' and 'views' values into their
        widgets (with the change observer detached while doing so), rebuild
        the per-class "parts" checkboxes, redraw the image, and finally sync
        the index text box.
        """
        at_last = (self.current_pos == self.max_pos)
        at_first = (self.current_pos == 0)
        self.next_button.disabled = at_last
        self.previous_button.disabled = at_first

        current_object = self.objects[self.current_pos]
        class_id = current_object['class_id']
        gt_id = current_object['gt_id']

        # occlusion level: cleared first, set only when details exist
        if 'occ' in self.radiobuttons:
            occ_radio = self.radiobuttons['occ']
            occ_options = occ_radio.options
            occ_radio.unobserve(self.__checkbox_changed)
            occ_radio.value = None  # clear the current value
            occ_details = self.dataset_voc.annotations_gt['gt'][class_id]['details'][gt_id]
            if occ_details:  # empty details mean nothing to show
                occ_radio.value = occ_options[occ_details['occ_level'] - 1]
            occ_radio.observe(self.__checkbox_changed)

        # truncation flag: option 0 when not truncated, option 1 when truncated
        if 'truncated' in self.radiobuttons:
            trunc_radio = self.radiobuttons['truncated']
            trunc_options = trunc_radio.options
            trunc_radio.unobserve(self.__checkbox_changed)
            is_truncated = self.dataset_voc.annotations_gt['gt'][class_id]['istrunc'][gt_id] == True
            trunc_radio.value = trunc_options[int(is_truncated)]
            trunc_radio.observe(self.__checkbox_changed)

        # visible sides: one checkbox per view, keyed by its description
        if 'views' in self.checkboxes:
            for view_box in self.checkboxes['views']:
                view_box.unobserve(self.__checkbox_changed)
                view_box.value = False  # clear the value
                view_details = self.dataset_voc.annotations_gt['gt'][class_id]['details'][gt_id]
                if view_details:  # empty details mean nothing to show
                    view_box.value = bool(view_details['side_visible'][view_box.description])
                view_box.observe(self.__checkbox_changed)

        # the parts checkboxes depend on the class, so they are rebuilt each time
        with self.dynamic_output_for_parts:
            self.dynamic_output_for_parts.clear_output()
            class_name = self.objects[self.current_pos]['class_name']
            parts_by_class = self.metrics_labels['parts']
            if class_name in parts_by_class:
                self.cb_parts = [Checkbox(False, description='{}'.format(part), indent=False)
                                 for part in parts_by_class[class_name]]
            else:
                self.cb_parts = [HTML(value="No PARTS defined in Conf file")]
            display(VBox(children=list(self.cb_parts)))

        # redraw the image of the current object
        with self.out:
            self.out.clear_output()
            image_record, obj_num = self.__get_image_record()
            self.image_display_function(image_record, obj_num)

        # sync the 1-based index box without re-triggering this method
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        current_class_id = self.objects[self.current_pos]['class_id']
        current_gt_id = self.objects[self.current_pos]['gt_id']

        obj_num = self.dataset_voc.annotations_gt['gt'][current_class_id]['onum'][current_gt_id]
        index_row = self.dataset_voc.annotations_gt['gt'][current_class_id]['rnum'][current_gt_id]
        r = self.dataset_voc.annotations_gt['rec'][index_row]
        return r, obj_num

    def __on_previous_clicked(self, b):
        self.save_state()
        self.current_pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        self.save_state()
        self.current_pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        self.save_state()
        image_record, _ = self.__get_image_record()
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        self.save_function(path_img)

    def __selected_index(self, t):
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.__perform_action()

    def start_classification(self):
        if self.max_pos < self.current_pos:
            print("No available images")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        """Print a table with the number of annotated images per class.

        Every category id found in the ``"categories"`` list of each record
        in ``self.dataset["images"]`` is counted; ids are translated through
        ``self.mapping["categories_id"]`` and the rows are sorted by that
        translated name.
        """
        per_class = defaultdict(int)
        for record in self.dataset["images"]:
            for class_id in record["categories"]:
                per_class[class_id] += 1

        rows = [[self.mapping["categories_id"][class_id], total]
                for class_id, total in per_class.items()]
        rows.sort(key=lambda row: row[0])
        print(tabulate(rows, headers=['Class name', 'Annotated images']))
示例#2
0
    class __Databases(object):
        """Database browser implementation
            
        Args:
            spark (SparkSession): Spark Session object
        """

        def __init__(self, spark):
            """Store the Spark session used by the browser.

            Args:
                spark (SparkSession): Spark Session object
            """
            self.spark = spark

        def create(self):
            """Create the sidecar view

            Builds the sidecar (titled with the last dash-separated part of the
            ``DBJL_CLUSTER`` environment variable), a refresh button, an empty
            accordion and a scrollable detail output, shows them inside the
            sidecar, then fills the accordion and installs the CSS.
            """
            cluster_suffix = os.environ["DBJL_CLUSTER"].split("-")[-1]
            self.sc = Sidecar(title="Databases-%s" % cluster_suffix, layout=Layout(width="300px"))

            refresh_button = Button(description="refresh")
            refresh_button.on_click(self.on_refresh)
            self.refresh = refresh_button

            detail_layout = Layout(height="600px", width="320px", overflow_x="scroll", overflow_y="scroll")
            self.output = Output(layout=detail_layout)
            self.output.add_class("db-detail")  # class targeted by set_css
            self.selects = []
            self.accordion = Accordion(children=[])

            with self.sc:
                display(VBox([self.refresh, self.accordion, self.output]))

            self.update()
            self.set_css()

        def on_refresh(self, b):
            """Refresh handler: drop the current select widgets and rebuild them

            Args:
                b (ipywidgets.Button): clicked button
            """
            self.selects = list()
            self.update()

        def update(self):
            """Rebuild the accordion with one select widget per database

            Runs ``show tables`` and groups the table names by the first
            column of each row. Rows whose third column is truthy get a
            " (temp)" suffix, and go into a "temp" group when their database
            column is empty. Select widgets are appended to ``self.selects``
            in sorted database order and become the accordion's children.
            """
            tables_by_db = {}
            for row in self.spark.sql("show tables").rdd.collect():
                db_name, table_name, is_temp = row[0], row[1], row[2]
                if is_temp and db_name == "":
                    db_name = "temp"
                label = "%s (temp)" % table_name if is_temp else table_name
                tables_by_db.setdefault(db_name, []).append(label)

            db_names = sorted(tables_by_db.keys())
            for db_name in db_names:
                # the leading blank option means "nothing selected"
                select = Select(options=[""] + sorted(tables_by_db[db_name]), disabled=False)
                select.observe(self.on_click(db_name, self), names="value")
                self.selects.append(select)
            self.accordion.children = self.selects
            for position, db_name in enumerate(db_names):
                self.accordion.set_title(position, db_name)

        def on_click(self, db, parent):
            """Click handler providing db and parent as context

            Args:
                db (str): database name
                parent (object): parent object

            Returns:
                function: observer taking the change notification; it clears
                ``parent.output`` and displays the extended description of
                the newly selected table inside it.
            """
            max_rows_key = "spark.sql.repl.eagerEval.maxNumRows"
            temp_suffix = " (temp)"

            def f(change):
                if change["old"] is None:
                    return
                parent.output.clear_output()

                selected = change["new"]
                if not selected:
                    # The blank first option was chosen: there is no table to
                    # describe, so leave the cleared output empty instead of
                    # running a query that can only fail.
                    return

                with parent.output:
                    if db == "temp":
                        table = selected
                    else:
                        table = "%s.%s" % (db, selected)
                    if table.endswith(temp_suffix):
                        table = table[:-len(temp_suffix)]

                    try:
                        schema = parent.spark.sql("describe extended %s" % table)
                        rows = int(parent.spark.conf.get(max_rows_key))
                        # temporarily raise the row limit so the whole
                        # description is rendered
                        parent.spark.conf.set(max_rows_key, 1000)
                        try:
                            display(schema)
                        finally:
                            # always restore the user's limit, even when
                            # rendering fails
                            parent.spark.conf.set(max_rows_key, rows)
                    except Exception:
                        # best effort: report instead of breaking the widget
                        # callback, but let KeyboardInterrupt/SystemExit pass
                        print("schema cannot be accessed, table most probably broken")

            return f

        def close(self):
            """Close view: forget the select widgets and close the sidecar"""
            self.selects = list()
            sidecar = self.sc
            sidecar.close()

        def set_css(self):
            """Set CSS: let widgets inside the ``db-detail`` output overflow visibly"""
            css = """
            <style>
            .db-detail .p-Widget {
            overflow: visible;
            }
            </style>
            """
            style_widget = HTML(css)
            display(style_widget)
示例#3
0
class Annotator:

    def __init__(self,
                 dataset,
                 metrics_and_values,  # properties{ name: ([possible values], just_one_value = True)}
                 output_path=None,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):
        """Build the widgets used to add meta-annotations to a dataset.

        Args:
            dataset: dataset wrapper; its ``dataset_root_param`` is the path
                of the source JSON file.
            metrics_and_values: ``{name: (values, kind)}``. ``kind`` is a
                ``MetaPropertiesTypes`` member: UNIQUE gives radio buttons over
                ``values``, COMPOUND one checkbox per value, and CONTINUE a
                bounded float box using ``values[0]`` / ``values[1]`` as
                minimum / maximum.
            output_path: directory of the ``<name>_ANNOTATED.json`` file;
                ``None`` uses the directory part of ``dataset_root_param``.
            show_name: stored and used by the default image display function
                to decide whether the file name and class are printed.
            show_axis: stored; the default display hides the axes when False.
            fig_size: figure size passed to matplotlib by the default display.
            buttons_vertical: stored as-is (not read in this constructor).
            image_display_function: callable ``(image_record, ann_key)`` used
                to draw the current annotation; defaults to the built-in one.
            classes_to_annotate: classes whose annotations are listed; ``None``
                means every category of the dataset.
        """

        self.dataset_orig = dataset
        self.metrics_and_values = metrics_and_values
        self.show_axis = show_axis
        self.name = (self.dataset_orig.dataset_root_param.split('/')[-1]).split('.')[0]  # get_original_file_name
        self.show_name = show_name
        if output_path is None:
            splitted_array = self.dataset_orig.dataset_root_param.split('/')
            n = len(splitted_array)
            self.output_directory = os.path.join(*(splitted_array[0:n - 1]))
        else:
            self.output_directory = output_path

        # Fix: the attribute used to be assigned only in the ``None`` case, so
        # an explicit class list crashed below with AttributeError.
        if classes_to_annotate is None:  # annotate every class
            self.classes_to_annotate = self.dataset_orig.get_categories_names()
        else:  # annotate only the requested classes
            self.classes_to_annotate = classes_to_annotate

        # the leading "/" restores the root dropped by the split/join above
        self.file_path_for_json = os.path.join("/", self.output_directory, self.name + "_ANNOTATED.json")
        print("New dataset with meta_annotations {}".format(self.file_path_for_json))
        self.mapping = None
        self.objects = self.dataset_orig.get_annotations_from_class_list(self.classes_to_annotate)
        self.max_pos = len(self.objects) - 1
        self.current_pos = 0
        self.mapping, self.dataset_annotated = self.__create_results_dict(self.file_path_for_json)

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # create buttons
        self.previous_button = self.__create_button("Previous", (self.current_pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", (self.current_pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button("Download", False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button, self.save_button]

        # 1-based index box plus the "/ total" label shown next to it
        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class("my_canvas_class")

        self.checkboxes = {}
        self.radiobuttons = {}
        self.bounded_text = {}

        output_layout = []
        for k_name, v in self.metrics_and_values.items():
            if MetaPropertiesTypes.META_PROP_UNIQUE == v[1]:  # radiobutton
                self.radiobuttons[k_name] = RadioButtons(name=k_name, options=v[0],
                                                         disabled=False,
                                                         indent=False)
            elif MetaPropertiesTypes.META_PROP_COMPOUND == v[1]:  # checkbox
                self.checkboxes[k_name] = [Checkbox(False, description='{}'.format(prop_name),
                                                    indent=False, name=k_name) for prop_name in v[0]]
            elif MetaPropertiesTypes.META_PROP_CONTINUE == v[1]:
                self.bounded_text[k_name] = BoundedFloatText(value=v[0][0], min=v[0][0], max=v[0][1])

        self.check_radio_boxes_layout = {}

        # Every widget gets its property name as a CSS class; the change
        # handler reads it back to know which property was edited.
        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            rb_v.add_class(rb_k)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
                cb.add_class(cb_k)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for bf_k, bf in self.bounded_text.items():
            bf.layout.width = '80px'
            bf.layout.height = '35px'
            bf.observe(self.__checkbox_changed)
            bf.add_class(bf_k)
            html_title = HTML(value="<b>" + bf_k + "</b>")
            self.check_radio_boxes_layout[bf_k] = VBox([bf])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[bf_k]]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox([self.out,
                                 VBox(output_layout)])])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __add_annotation_to_mapping(self, ann):
        cat_name = self.dataset_orig.get_category_name_from_id(ann['category_id'])
        if ann['id'] not in self.mapping['annotated_ids']:  # o nly adds category counter if it was not annotate
            self.mapping['categories_counter'][cat_name] += 1

        self.mapping['annotated_ids'].add(ann['id'])

    def __create_results_dict(self, file_path):
        mapping = {}
        mapping["annotated_ids"] = set()  # key: object_id from dataset, values=[(annotations done)]
        mapping["categories_counter"] = dict.fromkeys([c for c in self.dataset_orig.get_categories_names()], 0)
        self.mapping = mapping

        if not os.path.exists("/" + file_path): #it does exist __ANNOTATED in the output directory
            with open(self.dataset_orig.dataset_root_param, 'r') as input_json_file:
                dataset_annotated = json.load(input_json_file)
                input_json_file.close()
            #take the same metaproperties already in file if it's not empty (like a new file)
            meta_prop = dataset_annotated['meta_properties'] if 'meta_properties' in dataset_annotated.keys() else []

            # adds the new annotations categories to dataset if it doesn't exist
            for k_name, v in self.metrics_and_values.items():
                new_mp_to_append = {
                    "name": k_name,
                    "type": v[1].value,
                    "values": [p for p in v[0]].sort()
                }

                names_prop_in_file = {m_p['name']: m_p for m_i, m_p in enumerate(dataset_annotated[
                                                                                     'meta_properties'])} if 'meta_properties' in dataset_annotated.keys() else None

                if 'meta_properties' not in dataset_annotated.keys():  # it is a new file
                    meta_prop.append(new_mp_to_append)
                    dataset_annotated['meta_properties'] = []

                elif names_prop_in_file is not None and k_name not in names_prop_in_file.keys():
                    # if there is a property with the same in meta_properties, it must be the same structure as the one proposed
                    meta_prop.append(new_mp_to_append)
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                elif names_prop_in_file is not None and k_name in names_prop_in_file.keys() and \
                        names_prop_in_file[k_name] == new_mp_to_append:
                    #we don't append because it's already there
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                else:
                    raise NameError("An annotation with the same name {} "
                                    "already exist in dataset {}, and it has different structure. Check properties.".format(
                        k_name, self.dataset_orig.dataset_root_param))

                #if k_name is in name_props_in_file and it's the same structure. No update is done.
            dataset_annotated['meta_properties'] = dataset_annotated['meta_properties'] + meta_prop
        else:
            with open(file_path, 'r') as classification_file:
                dataset_annotated = json.load(classification_file)
                classification_file.close()

            self.__update_annotation_counter_and_current_pos(dataset_annotated)
        return mapping, dataset_annotated

    def __update_annotation_counter_and_current_pos(self, dataset_annotated):
        prop_names = set(self.metrics_and_values.keys())
        last_ann_id = self.objects[0]['id']
        for ann in dataset_annotated['annotations']:
            if prop_names.issubset(
                    set(ann.keys())):  # we consider that it was annotated when all the props are subset of keys
                self.__add_annotation_to_mapping(ann)
                last_ann_id = ann['id']  # to get the last index in self.object so we can update the current pos in the last one anotated
            else:
                break
        self.current_pos = next(i for i, a in enumerate(self.objects) if a['id'] == last_ann_id)

    def __checkbox_changed(self, b):
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value
        annotation_name = b['owner']._dom_classes[0]

        ann_id = self.objects[self.current_pos]['id']
        # image_id = self.objects[self.current_pos]['image_id']
        for ann in self.dataset_annotated['annotations']:
            if ann['id'] == ann_id:
                break
        if self.metrics_and_values[annotation_name][1] in [MetaPropertiesTypes.META_PROP_COMPOUND]:
            if annotation_name not in ann.keys():
                ann[annotation_name] = {p: 0 for p in self.metrics_and_values[annotation_name][0]}
            ann[annotation_name][class_name] = int(value)
        else:  # UNIQUE VALUE
            ann[annotation_name] = value

    def __create_button(self, description, disabled, function):
        """Create a ``Button`` with the given caption, state and click callback.

        Args:
            description: text shown on the button.
            disabled: initial value of the button's ``disabled`` attribute.
            function: callback registered through ``on_click``.

        Returns:
            The configured button.
        """
        new_button = Button(description=description)
        new_button.on_click(function)
        new_button.disabled = disabled
        return new_button

    def __show_image(self, image_record, ann_key):
        """Default display: draw the image with the outline of one annotation.

        Args:
            image_record: record providing the image ``'file_name'``.
            ann_key: key of the annotation in ``self.dataset_orig.coco_lib.anns``.

        For non-segmentation datasets the annotation's ``'bbox'`` is read as
        ``[x, y, width, height]`` and drawn as a rectangle; otherwise every
        entry of ``'segmentation'`` is read as a flat ``x0, y0, x1, y1, ...``
        list and drawn as a filled polygon. The class name is written at the
        first corner / first point.
        """
        # locate and open the picture
        path_img = os.path.join(self.dataset_orig.images_abs_path, image_record['file_name'])
        img = Image.open(path_img)
        if self.show_name:
            caption = '. Class: {} [class_id={}]'.format(
                self.dataset_orig.get_category_name_from_id(self.objects[self.current_pos]['category_id']),
                self.objects[self.current_pos]['category_id'])
            print(os.path.basename(path_img) + caption)
        plt.figure(figsize=self.fig_size)

        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)

        # one colour per category, spread over the rainbow colormap
        ax = plt.gca()
        class_colors = cm.rainbow(np.linspace(0, 1, len(self.dataset_orig.get_categories_names())))

        annotation = self.dataset_orig.coco_lib.anns[ann_key]
        object_class_name = self.dataset_orig.get_category_name_from_id(annotation['category_id'])
        c = class_colors[self.dataset_orig.get_categories_names().index(object_class_name)]
        red, green, blue = c[0], c[1], c[2]

        if self.dataset_orig.is_segmentation:
            seg_points = annotation['segmentation']
            for flat_points in seg_points:
                # pair up the flat list into (x, y) vertices
                vertices = [[float(flat_points[k]), float(flat_points[k + 1])]
                            for k in range(0, len(flat_points), 2)]
                ax.add_patch(
                    Polygon(np.array(vertices), linestyle='-', facecolor=(red, green, blue, 0.25),
                            edgecolor=(red, green, blue, 1.0), linewidth=2))
            # the label is anchored at the first point of the first polygon
            bbox_x1, bbox_y1 = seg_points[0][0], seg_points[0][1]
        else:
            [bbox_x1, bbox_y1, diff_x, diff_y] = annotation['bbox']
            bbox_x2 = bbox_x1 + diff_x
            bbox_y2 = bbox_y1 + diff_y
            corners = [[bbox_x1, bbox_y1], [bbox_x1, bbox_y2], [bbox_x2, bbox_y2],
                       [bbox_x2, bbox_y1]]
            # transparent face, opaque edge
            ax.add_patch(
                Polygon(np.array(corners).reshape((4, 2)), linestyle='-', facecolor=(red, green, blue, 0.0),
                        edgecolor=(red, green, blue, 1.0), linewidth=2))

        # class name on a half-transparent label
        ax.text(x=bbox_x1, y=bbox_y1, s=object_class_name, color='white', fontsize=9,
                horizontalalignment='left', verticalalignment='top',
                bbox=dict(facecolor=(red, green, blue, 0.5)))
        plt.show()

    def save_state(self):
        """Write the annotated dataset to disk and register the current object in the mapping.

        The whole ``self.dataset_annotated`` structure is serialized as JSON to
        ``self.file_path_for_json``. Afterwards the annotation whose ``'id'`` matches the
        object at ``self.current_pos`` is passed to ``__add_annotation_to_mapping``.

        :raises StopIteration: if no annotation carries the id of the current object
        """
        writer = SafeWriter(os.path.join(self.file_path_for_json), "w")
        try:
            writer.write(json.dumps(self.dataset_annotated))
        finally:
            # always release the writer, even when serialization or the write fails
            writer.close()

        current_id = self.objects[self.current_pos]['id']
        self.__add_annotation_to_mapping(next(ann_dat for ann_dat in self.dataset_annotated['annotations']
                                              if ann_dat['id'] == current_id))

    def __save_function(self, image_path):
        """Emit JavaScript that downloads a PNG screenshot of the output area.

        The script targets the element carrying the ``my_canvas_class`` CSS class and
        names the download after ``image_path``'s base name up to its first dot.

        :param image_path: path whose base name is used for the downloaded file
        """
        file_stem = os.path.basename(image_path).partition('.')[0]
        script = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """.replace('$it_name$', "my_canvas_class").replace('$img_name$', file_stem)
        # the script is displayed inside a throw-away Output widget that is cleared right after
        scratch = Output()
        with scratch:
            display(Javascript(script))
            scratch.clear_output()

    def __perform_action(self):
        """Refresh the whole UI for the object at ``self.current_pos``.

        Updates the enabled state of the navigation buttons, loads the stored values of the
        current annotation into the property widgets, redraws the image in ``self.out`` and
        synchronizes the index text box.

        :raises StopIteration: if no annotation carries the id of the current object
        """
        self.next_button.disabled = (self.current_pos == self.max_pos)
        self.previous_button.disabled = (self.current_pos == 0)

        current_id = self.objects[self.current_pos]['id']
        current_ann = next(a for a in self.dataset_annotated['annotations'] if a['id'] == current_id)

        # Handlers are detached while values are set programmatically so that loading
        # an annotation does not fire the change callback.
        for m_k, m_v in self.metrics_and_values.items():
            if m_v[1] == MetaPropertiesTypes.META_PROP_UNIQUE:  # radiobutton
                self.radiobuttons[m_k].unobserve(self.__checkbox_changed)
                self.radiobuttons[m_k].value = current_ann.get(m_k)
                self.radiobuttons[m_k].observe(self.__checkbox_changed)
            elif m_v[1] == MetaPropertiesTypes.META_PROP_COMPOUND:  # checkbox
                stored = current_ann.get(m_k, {})
                for cb_v in self.checkboxes[m_k]:
                    cb_v.unobserve(self.__checkbox_changed)
                    # an option missing from the stored annotation counts as unchecked
                    cb_v.value = bool(stored.get(cb_v.description, False))
                    cb_v.observe(self.__checkbox_changed)
            elif m_v[1] == MetaPropertiesTypes.META_PROP_CONTINUE:  # textbound
                self.bounded_text[m_k].unobserve(self.__checkbox_changed)
                self.bounded_text[m_k].value = float(current_ann[m_k]) if m_k in current_ann else \
                    self.bounded_text[m_k].min
                self.bounded_text[m_k].observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            image_record, ann_key = self.__get_image_record()
            self.image_display_function(image_record, ann_key)

        # update the index box without re-triggering __selected_index
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        """Look up the image record and annotation id of the object at ``self.current_pos``.

        :return: tuple ``(img_record, ann_key)`` where ``img_record`` comes from
            ``self.dataset_orig.coco_lib.imgs`` and ``ann_key`` is the object's ``'id'``
        """
        current_object = self.objects[self.current_pos]
        image_id = current_object['image_id']
        ann_key = current_object['id']
        return self.dataset_orig.coco_lib.imgs[image_id], ann_key

    def __on_previous_clicked(self, b):
        """Click handler for "previous": save the state, step back one object and redraw.

        :param b: the clicked button (unused)
        """
        self.save_state()
        self.current_pos = self.current_pos - 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Click handler for "next": save the state, advance one object and redraw.

        :param b: the clicked button (unused)
        """
        self.save_state()
        self.current_pos = self.current_pos + 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Click handler for the save button: persist the state, then export the current view.

        The path handed to ``self.save_function`` is built from ``self.output_directory``,
        the ``'JPEGImages'`` folder name and the current image's file name.

        :param b: the clicked button (unused)
        """
        self.save_state()
        record = self.__get_image_record()[0]
        target_path = os.path.join(self.output_directory, 'JPEGImages', record['file_name'])
        self.save_function(target_path)

    def __selected_index(self, t):
        """Observer of the index text box: jump to the object the user typed.

        Notifications for anything other than the ``'value'`` trait, or arriving while the
        owner's value is ``None``, are ignored.

        :param t: change notification; the keys ``'owner'``, ``'name'`` and ``'new'`` are read
        """
        if t['owner'].value is not None and t['name'] == 'value':
            # the box is 1-based, positions are 0-based
            self.current_pos = t['new'] - 1
            self.__perform_action()

    def start_annotation(self):
        """Display the widgets and show the current object, unless there is nothing to show."""
        if self.max_pos < self.current_pos:
            print("No available images")
        else:
            display(self.all_widgets)
            self.__perform_action()

    def print_statistics(self):
        """Print a table with the per-class counters plus a total over ``self.objects``.

        Rows come from ``self.mapping["categories_counter"]`` sorted by class name; the last
        row shows the summed counters against the number of objects.
        """
        counters = self.mapping["categories_counter"]
        rows = sorted(([c_k, c_number] for c_k, c_number in counters.items()),
                      key=lambda row: row[0])
        annotated_total = sum(counters.values())
        rows.append(['Total', '{}/{}'.format(annotated_total, len(self.objects))])
        print(tabulate(rows, headers=['Class name', 'Annotated objects']))
示例#4
0
class Iterator:
    """Notebook widget for stepping through a list of images.

    It offers previous / next / (optional) random / save buttons, an index box for jumping
    to a position, and an output area in which the current item is rendered.
    """

    def __init__(self,
                 images,
                 name="iterator",
                 show_name=True,
                 show_axis=False,
                 show_random=True,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None):
        """Build all widgets; nothing is displayed until ``start_iteration`` is called.

        :param images: non-empty sequence of items handed to the display function
            (image paths when the built-in display function is used)
        :param name: CSS class added to the output area, also used by the save script
        :param show_name: print the file name above the image (built-in display only)
        :param show_axis: keep the matplotlib axes visible (built-in display only)
        :param show_random: add a "Random" button
        :param fig_size: matplotlib figure size (built-in display only)
        :param buttons_vertical: stack the controls in a column left of the output
        :param image_display_function: optional callable ``(item, index)`` replacing the
            built-in renderer
        :raises Exception: if ``images`` is empty
        """
        if len(images) == 0:
            raise Exception("No images provided")

        self.images = images
        self.name = name
        self.show_name = show_name
        self.show_axis = show_axis
        self.show_random = show_random
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical
        self.pos = 0
        self.max_pos = len(self.images) - 1

        # fall back to the built-in renderer when the caller supplies none
        self.image_display_function = (self.__show_image if image_display_function is None
                                       else image_display_function)

        self.previous_button = self.__create_button("Previous", self.pos == 0,
                                                    self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", self.pos == self.max_pos,
                                                self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function

        # button order: previous, next, [random], save
        buttons = [self.previous_button, self.next_button]
        if self.show_random:
            self.random_button = self.__create_button("Random", False, self.__on_random_clicked)
            buttons.append(self.random_button)
        buttons.append(self.save_button)

        n_images = len(self.images)
        label_total = Label(value='/ {}'.format(n_images))
        self.text_index = BoundedIntText(value=1, min=1, max=n_images)
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)

        index_row = HBox([self.text_index, label_total])
        if self.buttons_vertical:
            controls = VBox(children=[index_row] + buttons)
            self.all_widgets = HBox(children=[controls, self.out])
        else:
            self.all_widgets = VBox(children=[index_row, HBox(children=buttons), self.out])

        # register the js library used to take html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label, disabled flag and click callback."""
        btn = Button(description=description)
        btn.on_click(function)
        btn.disabled = disabled
        return btn

    def __show_image(self, image_path, index):
        """Built-in renderer: open ``image_path`` and plot it (``index`` is unused)."""
        img = Image.open(image_path)
        if self.show_name:
            print(os.path.basename(image_path))
        plt.figure(figsize=self.fig_size)
        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)
        plt.show()

    def __save_function(self, image_path, index):
        """Emit JavaScript that downloads a PNG screenshot of this iterator's output area.

        The download is named after ``image_path``'s base name up to its first dot;
        ``index`` is unused.
        """
        file_stem = os.path.basename(image_path).partition('.')[0]
        script = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """.replace('$it_name$', self.name).replace('$img_name$', file_stem)
        # the script runs inside a throw-away Output widget that is cleared right after
        scratch = Output()
        with scratch:
            display(Javascript(script))
            scratch.clear_output()

    def __perform_action(self, index, max_pos):
        """Show item ``index`` and bring buttons and index box in line with it."""
        self.previous_button.disabled = (index == 0)
        self.next_button.disabled = (index == max_pos)

        with self.out:
            self.out.clear_output()
        with self.out:
            self.image_display_function(self.images[index], index)

        # detach the observer so the programmatic update does not call back into us
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = index + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        """Click handler: move one position back."""
        self.pos = self.pos - 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_next_clicked(self, b):
        """Click handler: move one position forward."""
        self.pos = self.pos + 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_random_clicked(self, b):
        """Click handler: jump to a random position in ``[0, max_pos]``."""
        self.pos = random.randint(0, self.max_pos)
        self.__perform_action(self.pos, self.max_pos)

    def __on_save_clicked(self, b):
        """Click handler: call the save function for the current item."""
        self.save_function(self.images[self.pos], self.pos)

    def __selected_index(self, t):
        """Observer of the index box: jump to the 1-based position that was entered."""
        if t['owner'].value is not None and t['name'] == 'value':
            self.pos = t['new'] - 1
            self.__perform_action(self.pos, self.max_pos)

    def start_iteration(self):
        """Display the widgets and render the item at the current position."""
        if self.max_pos < self.pos:
            print("No available images")
        else:
            display(self.all_widgets)
            self.__perform_action(self.pos, self.max_pos)
class AnnotatorInterface(metaclass=abc.ABCMeta):
    """Abstract base for notebook annotators that attach meta-properties to dataset entries.

    The constructor builds the shared widget layout (index box, navigation buttons,
    validation banner, output area and one input widget per property). Subclasses supply
    the dataset-specific parts through the abstract methods; in particular ``set_objects``
    is expected to provide ``self.objects`` and ``self.max_pos``, which the widget
    construction reads.
    """

    def __init__(self,
                 dataset,
                 properties_and_values,  # properties { name: (MetaPropType, [possible values], optional label)}
                 output_path=None,
                 show_name=False,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 ds_name=None,
                 custom_display_function=None,
                 classes_to_annotate=None,
                 validate_function=None,
                 show_reset=True
                 ):
        """Load or create the annotated dataset and build all widgets.

        :param dataset: dataset object; ``dataset_root_param`` and
            ``get_categories_names()`` are read from it here
        :param properties_and_values: ``{name: (MetaPropertiesType, [possible values],
            optional label)}``
        :param output_path: directory for the ``*_ANNOTATED.json`` file; defaults to the
            directory part of ``dataset.dataset_root_param``
        :param show_name: stored on the instance for the display code
        :param show_axis: stored on the instance for the display code
        :param fig_size: stored on the instance for the display code
        :param buttons_vertical: stored on the instance
        :param ds_name: base name of the output file; derived from
            ``dataset.dataset_root_param`` when omitted
        :param custom_display_function: forwarded to ``set_display_function``
        :param classes_to_annotate: classes to annotate; all dataset categories when ``None``
        :param validate_function: optional callable invoked by ``execute_validation``
        :param show_reset: add a reset button
        :raises NotImplementedError: if a property type is neither UNIQUE nor TEXT
        """
        for v in properties_and_values.values():
            if v[0] not in [MetaPropertiesType.UNIQUE, MetaPropertiesType.TEXT]:
                raise NotImplementedError(f"Cannot use {v[0]}!")

        self.dataset_orig = dataset
        self.properties_and_values = properties_and_values
        self.show_axis = show_axis
        self.show_reset = show_reset
        self.show_name = show_name

        if classes_to_annotate is None:
            # no explicit selection: annotate every category of the dataset
            self.classes_to_annotate = self.dataset_orig.get_categories_names()
        else:
            # keep the caller's selection (previously dropped, leaving the attribute unset)
            self.classes_to_annotate = classes_to_annotate

        if ds_name is None:
            # file name of the dataset without directory and without extension
            self.name = (self.dataset_orig.dataset_root_param.split('/')[-1]).split('.')[0]
        else:
            self.name = ds_name

        self.set_output(output_path)

        print("{} {}".format(labels_str.info_new_ds, self.file_path_for_json))

        self.current_pos = 0
        self.mapping, self.dataset_annotated = self.create_results_dict(self.file_path_for_json)
        self.set_objects()

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        self.current_image = {}

        # also creates self.text_index
        label_total = self.create_label_total()

        # create buttons
        buttons = self.create_buttons()

        self.updated = False

        self.validation_show = HTML(value="")
        self.out = Output()
        self.out.add_class("my_canvas_class")  # class name targeted by the screenshot script

        # one registry per widget kind, keyed by property name
        self.checkboxes = {}
        self.radiobuttons = {}
        self.bounded_text = {}
        self.box_text = {}

        labels = self.create_check_radio_boxes()

        self.validate = validate_function is not None
        if self.validate:
            self.validate_function = validate_function

        self.set_display_function(custom_display_function)

        output_layout = self.set_check_radio_boxes_layout(labels)

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 self.validation_show,
                                 HBox([self.out,
                                       VBox(output_layout)])])

        self.load_js()

    @abc.abstractmethod
    def set_objects(self):
        """Prepare the items to annotate (abstract)."""
        pass

    @abc.abstractmethod
    def set_display_function(self, custom_display_function):
        """Select the function used to render an item (abstract)."""
        pass

    def set_output(self, output_path):
        """Set ``self.output_directory`` and ``self.file_path_for_json``.

        :param output_path: target directory; when ``None`` the directory part of
            ``self.dataset_orig.dataset_root_param`` (trailing separator kept) is used
        """
        if output_path is None:
            name = self.dataset_orig.dataset_root_param
            # Cut only the trailing file name. str.replace would also remove earlier
            # occurrences of the same text (e.g. "/data/set/set" -> "/data//").
            self.output_directory = name[:len(name) - len(os.path.basename(name))]
        else:
            self.output_directory = output_path

        self.file_path_for_json = os.path.join(self.output_directory, self.name + "_ANNOTATED.json")

    def create_label_total(self):
        """Create the index text box (``self.text_index``) and return the "/ total" label."""
        n_objects = len(self.objects)
        label_total = Label(value='/ {}'.format(n_objects))
        self.text_index = BoundedIntText(value=1, min=1, max=n_objects)
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.selected_index)
        return label_total

    def create_buttons(self):
        """Create the navigation / reset / download buttons.

        :return: list of buttons in display order
        """
        self.previous_button = self.create_button(labels_str.str_btn_prev, (self.current_pos == 0),
                                                  self.on_previous_clicked)
        self.next_button = self.create_button(labels_str.str_btn_next, (self.current_pos == self.max_pos),
                                              self.on_next_clicked)
        self.save_button = self.create_button(labels_str.str_btn_download, False, self.on_save_clicked)

        if self.show_reset:
            self.reset_button = self.create_button(labels_str.str_btn_reset, False, self.on_reset_clicked)
            buttons = [self.previous_button, self.next_button, self.reset_button, self.save_button]
        else:
            buttons = [self.previous_button, self.next_button, self.save_button]
        return buttons

    def create_check_radio_boxes(self):
        """Create one input widget per property and fill the per-kind registries.

        :return: dict mapping each property name to the label shown above its widget
        """
        labels = dict()
        for k_name, v in self.properties_and_values.items():

            # label precedence: explicit third element, then (TEXT only) the second
            # element, then the property name itself
            if len(v) == 3:
                label = v[2]
            elif MetaPropertiesType.TEXT.value == v[0].value and len(v) == 2:
                label = v[1]
            else:
                label = k_name

            labels[k_name] = label

            if MetaPropertiesType.UNIQUE.value == v[0].value:  # radiobutton
                self.radiobuttons[k_name] = RadioButtons(name=k_name, options=v[1],
                                                         disabled=False,
                                                         indent=False)
            elif MetaPropertiesType.COMPOUND.value == v[0].value:  # checkbox
                self.checkboxes[k_name] = [Checkbox(False, indent=False, name=k_name,
                                                    description=prop_name) for prop_name in v[1]]
            elif MetaPropertiesType.CONTINUE.value == v[0].value:
                # v[1] is read as (min, max); the box starts at min
                self.bounded_text[k_name] = BoundedFloatText(value=v[1][0], min=v[1][0], max=v[1][1])

            elif MetaPropertiesType.TEXT.value == v[0].value:
                self.box_text[k_name] = Textarea(disabled=False)

        return labels

    def set_check_radio_boxes_layout(self, labels):
        """Size, observe and wrap every property widget under a bold title.

        Also fills ``self.check_radio_boxes_layout`` with the box holding each property's
        widgets.

        :param labels: dict of property name -> title text
        :return: list of titled boxes ordered radio buttons, checkboxes, bounded numbers,
            text areas
        """
        output_layout = []
        self.check_radio_boxes_layout = {}
        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.checkbox_changed)
            rb_v.add_class(rb_k)
            html_title = HTML(value="<b>" + labels[rb_k] + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.checkbox_changed)
                cb.add_class(cb_k)
            html_title = HTML(value="<b>" + labels[cb_k] + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for bf_k, bf in self.bounded_text.items():
            bf.layout.width = '80px'
            bf.layout.height = '35px'
            bf.observe(self.checkbox_changed)
            bf.add_class(bf_k)
            html_title = HTML(value="<b>" + labels[bf_k] + "</b>")
            self.check_radio_boxes_layout[bf_k] = VBox([bf])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[bf_k]]))

        for tb_k, tb_i in self.box_text.items():
            tb_i.layout.width = '500px'
            tb_i.observe(self.checkbox_changed)
            tb_i.add_class(tb_k)
            html_title = HTML(value="<b>" + labels[tb_k] + "</b>")
            self.check_radio_boxes_layout[tb_k] = VBox([tb_i])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[tb_k]]))

        return output_layout

    def load_js(self):
        """Register the html2canvas library used to take html screenshots."""
        j_code = """
                        require.config({
                            paths: {
                                html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                            }
                        });
                    """
        display(Javascript(j_code))

    def change_check_radio_boxes_value(self, current_ann):
        """Load the values stored in ``current_ann`` into the property widgets.

        Observers are detached while a value is set so that the programmatic update does
        not call ``checkbox_changed``. Properties missing from the annotation fall back to
        ``None`` (radio), ``False`` (checkbox), the widget minimum (bounded number) or
        ``""`` (text).

        :param current_ann: annotation dict keyed by property name
        """
        for m_k, m_v in self.properties_and_values.items():
            if m_v[0].value == MetaPropertiesType.UNIQUE.value:  # radiobutton
                self.radiobuttons[m_k].unobserve(self.checkbox_changed)
                self.radiobuttons[m_k].value = current_ann[m_k] if m_k in current_ann else None
                self.radiobuttons[m_k].observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.COMPOUND.value:  # checkbox
                for cb_v in self.checkboxes[m_k]:
                    cb_v.unobserve(self.checkbox_changed)
                    if m_k in current_ann and cb_v.description in current_ann[m_k]:
                        cb_v.value = current_ann[m_k][cb_v.description]
                    else:
                        cb_v.value = False
                    cb_v.observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.CONTINUE.value:  # textbound
                self.bounded_text[m_k].unobserve(self.checkbox_changed)
                self.bounded_text[m_k].value = float(current_ann[m_k]) if m_k in current_ann else \
                    self.bounded_text[m_k].min
                self.bounded_text[m_k].observe(self.checkbox_changed)
            elif m_v[0].value == MetaPropertiesType.TEXT.value:  # text
                self.box_text[m_k].unobserve(self.checkbox_changed)
                self.box_text[m_k].value = current_ann[m_k] if m_k in current_ann else ""
                self.box_text[m_k].observe(self.checkbox_changed)

    def create_results_dict(self, file_path):
        """Load the annotated dataset, creating it from the original one when needed.

        When ``file_path`` does not exist, the original dataset file is loaded and its
        ``'meta_properties'`` list is extended with the properties configured on this
        annotator. Otherwise ``file_path`` is loaded as is.

        :param file_path: path of the ``*_ANNOTATED.json`` file
        :return: tuple ``(mapping, dataset_annotated)``
        :raises NameError: if the source file already declares a property with the same
            name but a different structure
        """
        mapping = {}
        mapping["annotated_ids"] = set()  # key: object_id from dataset, values=[(annotations done)]
        mapping["categories_counter"] = dict.fromkeys(self.dataset_orig.get_categories_names(), 0)
        # published early because the update hook called below may rely on it
        self.mapping = mapping

        if not os.path.exists(file_path):  # no annotated file in the output directory yet
            with open(self.dataset_orig.dataset_root_param, 'r') as input_json_file:
                dataset_annotated = json.load(input_json_file)

            # start from the meta-properties already in the file, if any
            # (this is the same list object as the one stored in the dataset)
            meta_prop = dataset_annotated['meta_properties'] if 'meta_properties' in dataset_annotated else []

            # add the configured properties that the dataset does not declare yet
            for k_name, v in self.properties_and_values.items():

                new_mp_to_append = {
                    "name": k_name,
                    "type": v[0].value,
                }

                if len(v) > 1:
                    # NOTE(review): for a TEXT property given as (type, label) this sorts
                    # the characters of the label - confirm that is intended
                    new_mp_to_append["values"] = sorted(v[1])

                if 'meta_properties' not in dataset_annotated:  # it is a new file
                    meta_prop.append(new_mp_to_append)
                    dataset_annotated['meta_properties'] = []
                    continue

                names_prop_in_file = {m_p['name']: m_p for m_p in dataset_annotated['meta_properties']}

                if k_name not in names_prop_in_file:
                    meta_prop.append(new_mp_to_append)
                    self.update_annotation_counter_and_current_pos(dataset_annotated)

                elif names_prop_in_file[k_name] == new_mp_to_append:
                    # same name and same structure: already declared, nothing to append
                    self.update_annotation_counter_and_current_pos(dataset_annotated)

                else:
                    raise NameError("An annotation with the same name {} "
                                    "already exist in dataset {}, and it has different structure. Check properties.".format(
                        k_name, self.dataset_orig.dataset_root_param))

            # meta_prop already holds the complete list. Concatenating it with
            # dataset_annotated['meta_properties'] duplicated every entry whenever the
            # file had meta-properties, because both names then refer to the same list.
            dataset_annotated['meta_properties'] = meta_prop
        else:
            with open(file_path, 'r') as classification_file:
                dataset_annotated = json.load(classification_file)

            self.update_annotation_counter_and_current_pos(dataset_annotated)
        return mapping, dataset_annotated

    @abc.abstractmethod
    def add_annotation_to_mapping(self, ann):
        """Record one annotation in ``self.mapping`` (abstract)."""
        pass

    @abc.abstractmethod
    def update_mapping_from_whole_dataset(self):
        """Rebuild the mapping from the dataset (abstract)."""
        pass

    @abc.abstractmethod
    def update_annotation_counter_and_current_pos(self, dataset_annotated):
        """Update counters / position from a loaded dataset (abstract)."""
        pass

    def execute_validation(self, ann):
        """Run the optional validation function and show its verdict in the banner.

        Does nothing when no validation function was supplied.

        :param ann: annotation passed to ``self.validate_function``
        """
        if self.validate:
            # NOTE(review): a truthy result selects the "not ok" message - confirm that
            # validate_function is meant to report problems rather than success
            if self.validate_function(ann):
                self.validation_show.value = labels_str.srt_validation_not_ok
            else:
                self.validation_show.value = labels_str.srt_validation_ok

    @abc.abstractmethod
    def show_name_func(self, image_record, path_img):
        """Show the name of the current item (abstract)."""
        pass

    @abc.abstractmethod
    def checkbox_changed(self, b):
        """Observer attached to every property widget (abstract)."""
        pass

    def create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label, disabled flag and click callback."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    @abc.abstractmethod
    def show_image(self, image_record, ann_key):
        """Render one item (abstract)."""
        pass

    @abc.abstractmethod
    def save_state(self):
        """Persist the current annotations (abstract)."""
        pass

    def save_function(self, image_path):
        """Emit JavaScript that downloads a PNG screenshot of the output area.

        The script targets the element with the ``my_canvas_class`` CSS class; the download
        is named after ``image_path``'s base name up to its first dot.

        :param image_path: path whose base name is used for the downloaded file
        """
        img_name = os.path.basename(image_path).split('.')[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', "my_canvas_class")
        j_code = j_code.replace('$img_name$', img_name)
        # the script runs inside a throw-away Output widget that is cleared right after
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    @abc.abstractmethod
    def perform_action(self):
        """Refresh the UI for the current position (abstract)."""
        pass

    @abc.abstractmethod
    def get_image_record(self):
        """Return the record of the current item (abstract)."""
        pass

    @abc.abstractmethod
    def on_reset_clicked(self, b):
        """Click handler of the reset button (abstract)."""
        pass

    def on_previous_clicked(self, b):
        """Click handler: save the state, step back one position and refresh."""
        self.save_state()
        self.current_pos -= 1
        self.perform_action()

    def on_next_clicked(self, b):
        """Click handler: save the state, advance one position and refresh."""
        self.save_state()
        self.current_pos += 1
        self.perform_action()

    @abc.abstractmethod
    def on_save_clicked(self, b):
        """Click handler of the download button (abstract)."""
        pass

    def selected_index(self, t):
        """Observer of the index box: jump to the 1-based position that was entered.

        Notifications for traits other than ``'value'``, or arriving while the owner's
        value is ``None``, are ignored.

        :param t: change notification; the keys ``'owner'``, ``'name'`` and ``'new'`` are read
        """
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.perform_action()

    def start_annotation(self):
        """Display the widgets and show the current item, unless nothing is left to show."""
        if self.max_pos < self.current_pos:
            print(labels_str.info_no_more_images)
            return
        display(self.all_widgets)
        self.perform_action()

    @abc.abstractmethod
    def print_statistics(self):
        """Print annotation statistics (abstract)."""
        pass

    @abc.abstractmethod
    def print_results(self):
        """Print annotation results (abstract)."""
        pass
示例#6
0
class DatasetAnnotatorClassification:
    """Jupyter widget for annotating observations with classification labels.

    One observation is shown at a time next to check boxes (multi-label
    tasks) or radio buttons (binary and single-label tasks).  Annotations
    are kept in ``self.dataset`` and written as JSON to
    ``<output_path>/<name>.json``; if that file already exists it is loaded
    so previously stored annotations are available again.
    """

    # Task types accepted by the constructor.
    supported_types = [
        TaskType.CLASSIFICATION_BINARY, TaskType.CLASSIFICATION_SINGLE_LABEL,
        TaskType.CLASSIFICATION_MULTI_LABEL
    ]

    def __init__(self,
                 task_type,
                 observations,
                 output_path,
                 name,
                 classes,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 custom_display_function=None,
                 is_image=True):
        """Validate the arguments, load or create the results and build the UI.

        :param task_type: one of ``supported_types``.
        :param observations: non-empty sequence of observations; image paths
            when ``is_image`` is true.
        :param output_path: directory of the JSON results file.
        :param name: base name of the results file; also added as a CSS
            class to the output widget.
        :param classes: class names to annotate with (at least two, exactly
            two for binary tasks).
        :param show_name: print the image file name above the image.
        :param show_axis: keep the matplotlib axes visible.
        :param fig_size: matplotlib figure size.
        :param buttons_vertical: stack the navigation buttons in a column
            beside the output instead of a row above it.
        :param custom_display_function: callable receiving the stored record
            of the observation; required when ``is_image`` is false.
        :param is_image: whether observations are image paths.
        :raises Exception: when one of the validations fails.
        """
        if task_type not in self.supported_types:
            raise Exception(labels_str.warn_task_not_supported)

        # len() rather than truthiness: observations may be any sized sequence.
        if len(observations) == 0:
            raise Exception(labels_str.warn_no_images)

        is_binary = task_type.value == TaskType.CLASSIFICATION_BINARY.value
        num_classes = len(classes)
        if num_classes <= 1:
            raise Exception(labels_str.warn_little_classes)
        if num_classes > 2 and is_binary:
            raise Exception(labels_str.warn_binary_only_two)

        self.is_image = is_image
        # Field of a stored record that identifies its observation.
        self.key = "path" if self.is_image else "observation"
        if not self.is_image and custom_display_function is None:
            raise Exception(labels_str.warn_display_function_needed)

        self.task_type = task_type
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        self.output_path = output_path
        self.file_path = os.path.join(self.output_path, self.name + ".json")
        print(labels_str.info_ds_output + self.file_path)
        self.mapping, self.dataset = self.__create_results_dict(
            self.file_path, classes)

        # Requested classes plus any category already stored in the file,
        # hence the second binary check.
        self.classes = list(self.mapping["categories_id"].values())
        if len(self.classes) > 2 and is_binary:
            raise Exception(labels_str.warn_binary_only_two +
                            " ".join(self.classes))

        self.observations = observations
        self.max_pos = len(self.observations) - 1
        self.pos = 0
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if custom_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = custom_display_function

        self.previous_button = self.__create_button(labels_str.str_btn_prev,
                                                    (self.pos == 0),
                                                    self.__on_previous_clicked)
        self.next_button = self.__create_button(labels_str.str_btn_next,
                                                (self.pos == self.max_pos),
                                                self.__on_next_clicked)
        self.save_button = self.__create_button(labels_str.str_btn_download,
                                                False, self.__on_save_clicked)
        self.save_function = self.__save_function

        buttons = [self.previous_button, self.next_button, self.save_button]

        label_total = Label(value='/ {}'.format(len(self.observations)))
        self.text_index = BoundedIntText(value=1,
                                         min=1,
                                         max=len(self.observations))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        # The screenshot script locates the output area through this class.
        self.out.add_class(name)

        if self.__is_multilabel():
            self.checkboxes = [
                Checkbox(False, description=str(class_name), indent=False)
                for class_name in self.classes
            ]
            for cb in self.checkboxes:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=list(self.checkboxes))
        else:
            self.checkboxes = RadioButtons(options=self.classes,
                                           disabled=False,
                                           indent=False)
            self.checkboxes.layout.width = '180px'
            self.checkboxes.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=[self.checkboxes])

        index_row = HBox([self.text_index, label_total])
        output_layout = HBox(children=[self.out, self.checkboxes_layout])
        if self.buttons_vertical:
            self.all_widgets = HBox(children=[
                VBox(children=[index_row] + buttons), output_layout
            ])
        else:
            self.all_widgets = VBox(
                children=[index_row,
                          HBox(children=buttons), output_layout])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        """Load or create the results and build the lookup tables.

        :param file_path: JSON results file; loaded when it exists.
        :param cc: class names requested for this session.
        :return: ``(mapping, dataset)`` where ``mapping`` holds
            ``categories_id`` (id -> name), ``categories_name`` (name -> id)
            and ``observations`` (observation key -> index in
            ``dataset["observations"]``).
        """
        mapping = {
            "categories_id": {},
            "categories_name": {},
            "observations": {}
        }

        if os.path.exists(file_path):
            with open(file_path, 'r') as classification_file:
                dataset = json.load(classification_file)
            for index, record in enumerate(dataset['observations']):
                mapping['observations'][record[self.key]] = index
        else:
            dataset = {'categories': [], "observations": []}

        for category in dataset['categories']:
            mapping['categories_id'][category["id"]] = category["name"]
            mapping['categories_name'][category["name"]] = category["id"]

        # Register the classes that are not stored yet.  A name is added only
        # once, and ids already used by the file are skipped so a new
        # category can never overwrite an existing one.
        next_id = len(dataset['categories']) + 1
        for c in cc:
            if c in mapping['categories_name']:
                continue
            while next_id in mapping['categories_id']:
                next_id += 1
            mapping['categories_id'][next_id] = c
            mapping['categories_name'][c] = next_id
            dataset["categories"].append({
                "supercategory": c,
                "name": c,
                "id": next_id
            })
            next_id += 1

        return mapping, dataset

    def __checkbox_changed(self, b):
        """Store the label chosen for the current observation.

        :param b: change notification of a check box or of the radio group;
            anything but a ``value`` change with a non-``None`` value is
            ignored.
        """
        if b['owner'].value is None or b['name'] != 'value':
            return

        widget = b['owner']
        value = widget.value
        record = self.dataset["observations"][self.mapping["observations"][
            self.observations[self.pos]]]

        if self.__is_multilabel():
            class_index = self.mapping["categories_name"][widget.description]
            categories = record["categories"]
            if value and class_index not in categories:
                categories.append(class_index)
            elif not value and class_index in categories:
                categories.remove(class_index)
        else:
            record["category"] = self.mapping["categories_name"][value]

        # Navigating away normally saves the state; on the last observation
        # there is no "next" click, so save right here.
        if self.pos == self.max_pos:
            self.save_state()

    def __is_multilabel(self):
        """Return True when the task type is multi-label classification."""
        return TaskType.CLASSIFICATION_MULTI_LABEL.value == self.task_type.value

    def __create_button(self, description, disabled, function):
        """Create a button with the given label, state and click handler."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_record):
        """Default display function: plot the image named by the record.

        Prints a message and returns without plotting when the record has no
        ``path`` entry or the file does not exist.

        :param image_record: stored record of the observation.
        """
        if 'path' not in image_record:
            print("missing path")
            return
        path = image_record['path']
        if not os.path.exists(path):
            print("Image cannot be load" + path)
            return

        with Image.open(path) as img:
            if self.show_name:
                print(os.path.basename(path))
            plt.figure(figsize=self.fig_size)
            if not self.show_axis:
                plt.axis('off')
            plt.imshow(img)
            plt.show()

    def save_state(self):
        """Write the whole annotation dataset to the JSON results file."""
        with open(self.file_path, 'w') as output_file:
            json.dump(self.dataset, output_file, indent=4)

    def __save_function(self, image_path, index):
        """Download a PNG screenshot of this annotator's output area.

        Runs a JavaScript snippet that renders the output area carrying the
        CSS class ``self.name`` with ``html2canvas`` and lets the browser
        download it as ``<name>.png``.

        :param image_path: path whose base name, cut at the first ``.``,
            names the downloaded file.
        :param index: position of the observation (unused).
        """
        img_name = os.path.basename(image_path).split('.')[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        j_code = j_code.replace('$img_name$', img_name)
        # Displayed in a throw-away output that is cleared straight away.
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        """Show the observation at ``self.pos`` and sync every widget to it.

        A record is created the first time an observation is visited.  The
        label widgets are detached from their observer while their values
        are set so that merely displaying an observation is not recorded as
        an annotation.
        """
        current = self.observations[self.pos]
        if current not in self.mapping["observations"]:
            observation = {self.key: current}
            if self.is_image:
                observation["file_name"] = os.path.basename(current)
            observation["id"] = len(self.mapping["observations"]) + 1
            if self.__is_multilabel():
                observation["categories"] = []
            self.dataset["observations"].append(observation)
            self.mapping["observations"][current] = len(
                self.dataset["observations"]) - 1

        record = self.dataset["observations"][self.mapping["observations"]
                                              [current]]
        self.next_button.disabled = (self.pos == self.max_pos)
        self.previous_button.disabled = (self.pos == 0)

        if self.__is_multilabel():
            # Records from a loaded file may lack the list; create it once.
            categories = record.setdefault("categories", [])
            for cb in self.checkboxes:
                cb.unobserve(self.__checkbox_changed)
                cb.value = self.mapping["categories_name"][
                    cb.description] in categories
                cb.observe(self.__checkbox_changed)
        else:
            self.checkboxes.unobserve(self.__checkbox_changed)
            # None (no annotation yet, or an unknown id) clears the selection.
            self.checkboxes.value = self.mapping["categories_id"].get(
                record.get("category"))
            self.checkboxes.observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            self.image_display_function(record)

        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.pos + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        """Save the state and show the previous observation."""
        self.save_state()
        self.pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Save the state and show the next observation."""
        self.save_state()
        self.pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Save the state and trigger the screenshot download."""
        self.save_state()
        self.save_function(self.observations[self.pos], self.pos)

    def __selected_index(self, t):
        """Jump to the 1-based position typed into the index box."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.pos = t['new'] - 1
        self.__perform_action()

    def start_classification(self):
        """Display the widgets and show the current observation."""
        if self.max_pos < self.pos:
            print("No available observation")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        """Print a table with the number of annotations per class name."""
        counter = {
            category_id: 0
            for category_id in self.mapping["categories_id"]
        }
        multilabel = self.__is_multilabel()

        for record in self.dataset["observations"]:
            if multilabel:
                for c in record.get("categories", ()):
                    counter[c] += 1
            elif "category" in record:
                counter[record["category"]] += 1

        table = sorted(
            ([self.mapping["categories_id"][c], total]
             for c, total in counter.items()),
            key=lambda row: row[0])

        print(
            tabulate(table,
                     headers=[
                         labels_str.info_class_name, labels_str.info_ann_images
                     ]))
示例#7
0
class MonitorV2:
    """Widget showing the latest sample of a signal and a plot of recent ones.

    Samples are ``(x, y)`` pairs kept in two parallel lists.  Optionally an
    operating range (limits) and tolerance levels are configured; the latest
    value is then styled according to whether it violates a limit, a
    tolerance or neither, and the levels are drawn as lines in the plot.

    NOTE(review): ``i_active`` / ``fix`` are presumably ipywidgets'
    ``interactive`` / ``fixed`` -- confirm against the file's imports.
    """
    _data: list
    _op_lim_max = None
    _op_lim_min = None
    _op_tol_max = None
    _op_tol_min = None
    _y_unit: str
    _x_unit: str
    _max_plots: int
    _x_pre_tick: timedelta
    _x_post_tick: timedelta
    _plot_trigger: Checkbox
    _graph: i_active
    _curr: i_active
    _out: Output
    _group: VBox
    _graph_layout: go.Layout

    def __init__(self,
                 title=DEF_TITLE,
                 y_unit="%",
                 x_unit="utc",
                 max_plots=5,
                 data=None,
                 y_min=DEF_Y_MIN,
                 y_max=DEF_Y_MAX,
                 x_pre_tick=None,
                 x_post_tick=None,
                 op_enabled=False,
                 op_lim_min=None,
                 op_lim_max=None,
                 op_tol_min=None,
                 op_tol_max=None):
        """Build the widget tree.

        :param title: heading of the monitor.
        :param y_unit: unit shown next to values and as y axis title.
        :param x_unit: label of the sample time and x axis title.
        :param max_plots: number of most recent samples to plot.
        :param data: optional ``[x_values, y_values]`` to start with; the
            two sequences are used as they are, not copied.
        :param y_min: lower end of the y axis (``DEF_Y_MIN`` when None).
        :param y_max: upper end of the y axis (``DEF_Y_MAX`` when None).
        :param x_pre_tick: margin before the first plotted sample.
        :param x_post_tick: margin after the last plotted sample.
        :param op_enabled: enable the operating range; when false all limits
            and tolerances are ignored.
        :param op_lim_min: lower limit.
        :param op_lim_max: upper limit.
        :param op_tol_min: lower tolerance, see ``_get_tolerance``.
        :param op_tol_max: upper tolerance, see ``_get_tolerance``.
        """
        # Hidden checkbox; toggling its value makes both views redraw.
        plot_trigger = Checkbox(value=False)
        plot_trigger.layout.visibility = 'hidden'

        y_low = DEF_Y_MIN if y_min is None else y_min
        y_high = DEF_Y_MAX if y_max is None else y_max
        lim_min = op_lim_min if op_enabled else None
        lim_max = op_lim_max if op_enabled else None

        span = MonitorV2._get_op_span(op_enabled, lim_min, lim_max, y_low,
                                      y_high)
        tol_max = MonitorV2._get_tolerance(op_enabled, op_tol_max, lim_max,
                                           span)
        # Negative span: the lower tolerance lies above the lower limit.
        tol_min = MonitorV2._get_tolerance(op_enabled, op_tol_min, lim_min,
                                           -span)

        def fixed_kwargs():
            # Separate ``fix`` wrappers for each view; only the trigger
            # checkbox is shared between them.
            return dict(t=plot_trigger,
                        e=fix(op_enabled),
                        l=fix(lim_min),
                        h=fix(lim_max),
                        i=fix(tol_min),
                        j=fix(tol_max))

        current_view = i_active(self.show_current, **fixed_kwargs())
        graph_view = i_active(self.show_graph, **fixed_kwargs())

        self._out = Output(layout=OUT_LAYOUT)
        self._out.add_class("mon-out")
        graph_out = graph_view.children[-1]
        graph_out.layout = GRA_OUT_LAYOUT

        def fmt(level):
            return "None" if level is None else f"{level:.2f}"

        # The unit is only shown when at least one of the two levels exists.
        lim_unit = y_unit if lim_min is not None or lim_max is not None else ''
        tol_unit = y_unit if tol_min is not None or tol_max is not None else ''
        state = '(on)' if op_enabled else '(off)'

        info_box = VBox([
            HTML(OUT_STYLE),
            Label(layout=Layout(width='350px')),
            HTML(f"<h2 style='{TIT_MARGIN}'>{title}</h2>"),
            HTML(f"<h3 style='{SUB_TIT_MARGIN}'>{TXT_OP_RANGE} {state}</h3>"),
            HBox([
                HTML(f"<h4 style='{LBL_MARGIN}'>{TXT_LIM}</h4>"),
                Label(f"min: {fmt(lim_min)}, max: {fmt(lim_max)} {lim_unit}")
            ]),
            HBox([
                HTML(f"<h4 style='{LBL_MARGIN}'>{TXT_TOL}:</h4>"),
                Label(f"min: {fmt(tol_min)}, max: {fmt(tol_max)} {tol_unit}")
            ]),
            HTML(f"<h3 style='{SUB_TIT_MARGIN}'>{TXT_CURRENT}</h3>"),
            current_view
        ])
        self._group = VBox([HBox([info_box, graph_view])])

        # Work on a copy: updating the module-level layout in place would
        # leak this monitor's axis titles and y range into every other one.
        self._graph_layout = go.Layout(GRA_LAYOUT)
        self._graph_layout.update(dict(xaxis_title=x_unit,
                                       yaxis_title=y_unit))
        self._graph_layout.yaxis.update(range=[y_low, y_high])
        self._graph_layout.update(xaxis_tickformat="%H:%M:%S")

        self._x_pre_tick = (x_pre_tick
                            if x_pre_tick is not None else DEF_TICK_TIME_DELTA)
        self._x_post_tick = (x_post_tick if x_post_tick is not None else
                             DEF_TICK_TIME_DELTA)

        self._data = [[], []]
        if data is not None:
            self._data = [data[0], data[1]]

        self._x_unit = x_unit
        self._y_unit = y_unit
        self._plot_trigger = plot_trigger
        self._max_plots = max_plots
        self._graph = graph_view
        self._op_lim_min = lim_min
        self._op_lim_max = lim_max
        self._op_tol_min = tol_min
        self._op_tol_max = tol_max
        self._curr = current_view
        # Earlier versions stored the view under this misspelled name only;
        # kept as an alias in case other code reads it.
        self._Curr = current_view

    @staticmethod
    def _get_op_span(enabled, lim_min, lim_max, y_low, y_high):
        """Return the width of the operating range used for % tolerances.

        A missing limit is replaced by the matching end of the y axis.  The
        result is 0 when the range is disabled or no limit is given.
        """
        if not enabled or (lim_min is None and lim_max is None):
            return 0
        low = y_low if lim_min is None else lim_min
        high = y_high if lim_max is None else lim_max
        return high - low

    @staticmethod
    def _get_tolerance(enabled, t_v, l_v, total):
        """Return the tolerance level belonging to a limit, or None.

        :param enabled: whether the operating range is enabled.
        :param t_v: tolerance; a value ending in ``%`` is a percentage of
            ``total`` measured from the limit, anything else is taken as the
            level itself.
        :param l_v: the limit the tolerance belongs to.
        :param total: signed width of the operating range.
        :return: the level, or None when disabled, when ``t_v`` or ``l_v`` is
            missing or when ``total`` is 0.
        """
        if not enabled or t_v is None or l_v is None or not total:
            return None
        text = str(t_v)
        if text.endswith("%"):
            return l_v - total * float(text.replace("%", "")) / 100.0
        return float(t_v)

    @staticmethod
    def _pick_state(value, enabled, lim_min, lim_max, tol_min, tol_max,
                    signal, default, tolerance, limit):
        """Choose one of four markers for ``value``.

        ``signal`` when the operating range is disabled, ``limit`` when a
        limit is violated, ``tolerance`` when only a tolerance is violated
        and ``default`` otherwise.
        """
        if not enabled:
            return signal
        if (lim_max is not None and value > lim_max) or \
                (lim_min is not None and value < lim_min):
            return limit
        if (tol_min is not None and value < tol_min) or \
                (tol_max is not None and value > tol_max):
            return tolerance
        return default

    def show_current(self, t, e, l, h, i, j):
        """Display the latest sample: value, sample time and local time.

        :param t: redraw trigger; nothing is shown when it is None.
        :param e: whether the operating range is enabled.
        :param l: lower limit or None.
        :param h: upper limit or None.
        :param i: lower tolerance level or None.
        :param j: upper tolerance level or None.
        """
        if t is None:
            return
        x_values, y_values = self._data[0], self._data[1]
        if not y_values:
            return
        xv, yv = x_values[-1], y_values[-1]
        cls = MonitorV2._pick_state(yv, e, l, h, i, j, "v-sig", "v-def",
                                    "v-tol", "v-lim")
        v = f"<div class='mon-val-y {cls}'><span>{yv:.3f} {self._y_unit}</span></div>"
        t1 = str(xv.isoformat(timespec='seconds')).replace('T', ' ')
        # Local UTC offset, rounded to whole seconds.  ``timedelta.seconds``
        # must not be used here: it is never negative, so offsets west of UTC
        # would come out almost a day too large.
        offset = timedelta(
            seconds=round((dt.now() - dt.utcnow()).total_seconds()))
        t2 = str(
            (xv + offset).isoformat(timespec='seconds')).replace('T', ' ')
        ts = f"<div class='mon-val-x'><span class='bold'>{self._x_unit.capitalize()}</span><span>: {t1}</span></div>"
        tl = f"<div class='mon-val-x2'><span class='bold'>{TXT_LOCAL.capitalize()}</span><span>: {t2}</span></div>"
        display(HTML(v + ts + tl))

    def show_graph(self, t, e, l, h, i, j):
        """Plot the most recent samples together with the configured levels.

        An empty figure is shown while there are no samples.  Parameters are
        the same as for ``show_current``.
        """
        if t is None:
            return
        x_all, y_all = self._data[0], self._data[1]
        # Clamp the start at 0: a negative start would silently drop samples
        # whenever fewer than ``_max_plots`` are available.
        x_data = x_all[max(len(x_all) - self._max_plots, 0):]
        y_data = y_all[max(len(y_all) - self._max_plots, 0):] if y_all \
            else y_all

        config = {"displayModeBar": False}
        fig = go.Figure(layout=self._graph_layout)
        fig.update_yaxes(nticks=10)

        has_x = len(x_data) > 0
        if has_x:
            x_min = x_data[0] - self._x_pre_tick
            x_max = x_data[-1] + self._x_post_tick
            fig.update_xaxes(range=[x_min, x_max])

        if has_x and y_data:
            if e:
                levels = ((h, TXT_LIM_MAX, LIN_LIM_MAX),
                          (l, TXT_LIM_MIN, LIN_LIM_MIN),
                          (j, TXT_TOL_MAX, LIN_TOL_MAX),
                          (i, TXT_TOL_MIN, LIN_TOL_MIN))
                for level, label, line in levels:
                    # ``is None`` rather than truthiness: 0 is a valid level.
                    if level is None:
                        continue
                    fig.add_trace(
                        go.Scatter(x=[x_min, x_max],
                                   y=[level, level],
                                   mode="lines",
                                   name=label,
                                   line=line,
                                   opacity=0.6))

            # Use the newest x values when there are fewer y than x values.
            x_plot = x_data[max(len(x_data) - len(y_data), 0):]
            fig.add_trace(
                go.Scatter(x=x_plot,
                           y=y_data,
                           mode="markers",
                           name=TXT_SIGNAL,
                           marker=MRK_SIGNAL,
                           opacity=0.6))

            # Highlight the latest sample, coloured by its state.
            latest = y_data[-1]
            m_clr = MonitorV2._pick_state(latest, e, l, h, i, j, CLR_SIGNAL,
                                          CLR_OP_DEF, CLR_OP_TOL, CLR_OP_LIM)
            fig.add_trace(
                go.Scatter(x=[x_data[-1]],
                           y=[latest],
                           mode="markers",
                           name=TXT_CURRENT,
                           opacity=1.0,
                           marker=dict(color=m_clr,
                                       size=12,
                                       symbol=SYM_CURRENT)))
        fig.show(config=config)

    def _plot_graph(self):
        """Toggle the hidden trigger so both views redraw."""
        self._plot_trigger.value = not self._plot_trigger.value

    def update(self, x, y):
        """Append the sample ``(x, y)`` and redraw."""
        x_data, y_data = self._data
        x_data.append(x)
        y_data.append(y)
        self._plot_graph()

    def output(self):
        """Display the widget tree in the monitor's output and return it."""
        with self._out:
            display(self._group)
        return self._out