Beispiel #1
0
class MultiClassAnnotator:
    """Notebook widget that steps through ground-truth objects of a VOC-style dataset.

    For each object the image is shown with the object's bounding box, and the
    stored 'occ', 'truncated' and 'views' values are mirrored in radio buttons
    and checkboxes. A "Save" button triggers a browser-side screenshot of the
    output area.

    NOTE(review): ``__create_results_dict`` currently returns two empty dicts,
    so ``__checkbox_changed`` (which indexes ``self.mapping["images"]``) cannot
    succeed and ``save_state`` writes an empty JSON object. Persisting edits
    looks unfinished -- confirm the intended storage format.
    """

    def __init__(self,
                 dataset_voc,
                 output_path_statistic,
                 name,
                 metrics,  # names handled below: 'occ', 'truncated', 'parts'; any other name becomes a checkbox group
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):
        """Build all widgets; nothing is displayed until ``start_classification``.

        Args:
            dataset_voc: dataset object; ``load()`` is called on it when its
                ``annotations_gt`` attribute is still ``None``.
            output_path_statistic: directory of the JSON state file and of the
                'JPEGImages' folder; defaults to ``dataset_voc.dataset_root_param``.
            name: base name of the JSON file and CSS class of the output area.
            metrics: iterable of metric names to expose as widgets.
            show_name: print file name and class above the image.
            show_axis: keep the matplotlib axes visible.
            fig_size: matplotlib figure size.
            buttons_vertical: stored but not used by this class.
            image_display_function: optional ``f(image_record, obj_num)``
                replacing the built-in renderer.
            classes_to_annotate: classes whose objects are listed; ``None``
                selects ``dataset_voc.objnames_all``.
        """
        if dataset_voc.annotations_gt is None:  # in case that dataset_voc has not been loaded
            dataset_voc.load()

        self.dataset_voc = dataset_voc
        self.metrics = metrics
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        if output_path_statistic is None:
            output_path_statistic = self.dataset_voc.dataset_root_param

        # Fix: an explicit class list used to be dropped, leaving the attribute
        # undefined and crashing on the get_objects_index() call below.
        if classes_to_annotate is None:
            classes_to_annotate = self.dataset_voc.objnames_all
        self.classes_to_annotate = classes_to_annotate

        self.output_path = output_path_statistic
        self.file_path = os.path.join(self.output_path, self.name + ".json")
        self.mapping, self.dataset = self.__create_results_dict(self.file_path, metrics)

        self.objects = dataset_voc.get_objects_index(self.classes_to_annotate)
        self.current_pos = 0
        self.max_pos = len(self.objects) - 1

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # create buttons
        self.previous_button = self.__create_button("Previous", (self.current_pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", (self.current_pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False, self.__on_save_clicked)
        self.save_function = self.__save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button, self.save_button]

        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)  # the screenshot JavaScript selects the output area by this class

        metrics_labels = self.dataset_voc.read_label_metrics_name_from_file()
        metrics_labels['truncated'] = {'0': 'False', '1': 'True'}
        self.metrics_labels = metrics_labels

        self.checkboxes = {}
        self.radiobuttons = {}

        output_layout = []
        for m_n in self.metrics:
            if 'parts' == m_n:  # special case: rebuilt per object in __perform_action
                continue

            if m_n in ['truncated', 'occ']:  # radiobutton
                self.radiobuttons[m_n] = RadioButtons(options=list(metrics_labels[m_n].values()),
                                                      disabled=False,
                                                      indent=False)
            else:  # checkbox
                self.checkboxes[m_n] = [Checkbox(False, description='{}'.format(metrics_labels[m_n][i]),
                                                 indent=False) for i in metrics_labels[m_n].keys()]

        self.check_radio_boxes_layout = {}
        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=list(cb_i))
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        # output area for the per-object "parts" checkboxes
        self.dynamic_output_for_parts = Output()
        html_title = HTML(value="<b>" + "Parts" + "</b>")
        output_layout.append(VBox([html_title, self.dynamic_output_for_parts]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox(output_layout),
                                 self.out])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        """Return the ``(mapping, dataset)`` pair backing the annotation state.

        Both are empty dicts for now and the arguments are ignored.
        NOTE(review): the loading/creation logic was disabled in the original
        code; restore it before relying on ``__checkbox_changed``.
        """
        return {}, {}

    def __checkbox_changed(self, b):
        """Observer for checkbox / radio-button changes.

        Ignores events that are not a change of ``value`` or whose value is None.
        NOTE(review): this indexes ``self.mapping["images"]`` and
        ``self.mapping["categories_name"]``, which do not exist while
        ``__create_results_dict`` returns empty dicts.
        """
        if b['owner'].value is None or b['name'] != 'value':
            return

        class_name = b['owner'].description
        value = b['owner'].value

        current_index = self.mapping["images"][self.objects[self.current_pos]]
        class_index = self.mapping["categories_name"][class_name]
        categories = self.dataset["images"][current_index]["categories"]
        if class_index not in categories and value:
            categories.append(class_index)
        if class_index in categories and not value:
            categories.remove(class_index)

    def __create_button(self, description, disabled, function):
        """Create a button with the given label, disabled state and click handler."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_record, obj_num):
        """Default renderer: show the image and the bounding box of object ``obj_num``.

        The image is read from ``<output_path>/JPEGImages/<filename>`` and the
        box is taken from ``image_record['objects'][obj_num]['bbox']`` as
        ``[x1, y1, x2, y2]``.
        """
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        img = Image.open(path_img)
        if self.show_name:
            print(os.path.basename(path_img) + '. Class: {} [class_id={}]'.format(
                self.objects[self.current_pos]['class_name'],
                self.objects[self.current_pos]['class_id']))
        plt.figure(figsize=self.fig_size)

        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)

        ax = plt.gca()
        # one colour per known class, picked by the class position in objnames_all
        class_colors = cm.rainbow(np.linspace(0, 1, len(self.dataset_voc.objnames_all)))

        [bbox_x1, bbox_y1, bbox_x2, bbox_y2] = image_record['objects'][obj_num]['bbox']
        poly = [[bbox_x1, bbox_y1], [bbox_x1, bbox_y2], [bbox_x2, bbox_y2],
                [bbox_x2, bbox_y1]]
        np_poly = np.array(poly).reshape((4, 2))

        object_class_name = image_record['objects'][obj_num]['class']
        c = class_colors[self.dataset_voc.objnames_all.index(object_class_name)]

        # draws the bbox (transparent fill, opaque edge)
        ax.add_patch(
            Polygon(np_poly, linestyle='-', facecolor=(c[0], c[1], c[2], 0.0),
                    edgecolor=(c[0], c[1], c[2], 1.0), linewidth=2))

        # write the class name in the top-left corner of the bbox
        ax.text(x=bbox_x1, y=bbox_y1, s=object_class_name, color='white', fontsize=9, horizontalalignment='left',
                verticalalignment='top',
                bbox=dict(facecolor=(c[0], c[1], c[2], 0.5)))
        plt.show()

    def save_state(self):
        """Write ``self.dataset`` as indented JSON to ``self.file_path``."""
        with open(self.file_path, 'w') as output_file:
            json.dump(self.dataset, output_file, indent=4)

    def __screenshot_js(self, image_path):
        """Return the JavaScript that screenshots the output area via html2canvas.

        The download is named after ``image_path`` with its extension replaced
        by ``.png``. ``os.path.splitext`` is used so that names containing
        dots (e.g. ``img.v2.jpg``) are no longer cut at the first dot.
        """
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) {
                    var myImage = canvas.toDataURL();
                    var a = document.createElement("a");
                    a.href = myImage;
                    a.download = "$img_name$.png";
                    a.click();
                    a.remove();
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        return j_code.replace('$img_name$', img_name)

    def __save_function(self, image_path):
        """Default save action: run the screenshot script in a throw-away output."""
        j_code = self.__screenshot_js(image_path)
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __perform_action(self):
        """Refresh buttons, annotation widgets and the image for ``self.current_pos``."""
        self.next_button.disabled = (self.current_pos == self.max_pos)
        self.previous_button.disabled = (self.current_pos == 0)

        current_object = self.objects[self.current_pos]
        current_class_id = current_object['class_id']
        current_gt_id = current_object['gt_id']
        gt_class = self.dataset_voc.annotations_gt['gt'][current_class_id]

        # Observers are detached while values are set programmatically so that
        # __checkbox_changed only fires for user edits.
        if 'occ' in self.radiobuttons:
            cb = self.radiobuttons['occ']
            rb_options = cb.options
            cb.unobserve(self.__checkbox_changed)
            cb.value = None  # clear the current value
            details = gt_class['details'][current_gt_id]
            if details:  # details may be empty for this object
                cb.value = rb_options[details['occ_level'] - 1]
            cb.observe(self.__checkbox_changed)

        if 'truncated' in self.radiobuttons:  # since this works for PASCAL VOC there's always a truncation value
            cb = self.radiobuttons['truncated']
            rb_options = cb.options
            cb.unobserve(self.__checkbox_changed)
            cb.value = rb_options[int(gt_class['istrunc'][current_gt_id] == True)]
            cb.observe(self.__checkbox_changed)

        if 'views' in self.checkboxes:
            for cb in self.checkboxes['views']:
                cb.unobserve(self.__checkbox_changed)
                cb.value = False  # clear the value
                details = gt_class['details'][current_gt_id]
                if details:  # details may be empty for this object
                    cb.value = bool(details['side_visible'][cb.description])
                cb.observe(self.__checkbox_changed)

        # the "parts" checkboxes depend on the class of the current object
        with self.dynamic_output_for_parts:
            self.dynamic_output_for_parts.clear_output()
            class_name = current_object['class_name']
            # tolerate a labels file without a 'parts' section instead of raising KeyError
            if 'parts' in self.metrics_labels and class_name in self.metrics_labels['parts']:
                self.cb_parts = [Checkbox(False, description='{}'.format(i), indent=False)
                                 for i in self.metrics_labels['parts'][class_name]]
            else:
                self.cb_parts = [HTML(value="No PARTS defined in Conf file")]
            display(VBox(children=list(self.cb_parts)))

        with self.out:
            self.out.clear_output()
            image_record, obj_num = self.__get_image_record()
            self.image_display_function(image_record, obj_num)

        # keep the index box in sync without re-triggering __selected_index
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        """Return ``(image_record, obj_num)`` for the current object.

        ``rnum`` gives the index into ``annotations_gt['rec']`` and ``onum``
        the object number handed to the display function.
        """
        current = self.objects[self.current_pos]
        gt_class = self.dataset_voc.annotations_gt['gt'][current['class_id']]
        gt_id = current['gt_id']

        obj_num = gt_class['onum'][gt_id]
        index_row = gt_class['rnum'][gt_id]
        return self.dataset_voc.annotations_gt['rec'][index_row], obj_num

    def __on_previous_clicked(self, b):
        """Save the state and move one object back."""
        self.save_state()
        self.current_pos -= 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Save the state and move one object forward."""
        self.save_state()
        self.current_pos += 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Save the state and hand the current image path to ``save_function``."""
        self.save_state()
        image_record, _ = self.__get_image_record()
        path_img = os.path.join(self.output_path, 'JPEGImages', image_record['filename'])
        self.save_function(path_img)

    def __selected_index(self, t):
        """Observer of the index box: jump to the 1-based index that was entered."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.current_pos = t['new'] - 1
        self.__perform_action()

    def start_classification(self):
        """Display the widgets and show the first object (or report that none exist)."""
        if self.max_pos < self.current_pos:
            print("No available images")
            return
        display(self.all_widgets)
        self.__perform_action()

    def print_statistics(self):
        """Print a table with the number of annotated images per class.

        An absent "images" entry in ``self.dataset`` is treated as no records,
        so an empty table is printed instead of raising ``KeyError``.
        """
        counter = defaultdict(int)
        for record in self.dataset.get("images", []):
            for c in record["categories"]:
                counter[c] += 1
        table = sorted(([self.mapping["categories_id"][c], total] for c, total in counter.items()),
                       key=lambda row: row[0])
        print(tabulate(table, headers=['Class name', 'Annotated images']))
Beispiel #2
0
class Iterator:
    """Notebook widget that pages through a sequence of images.

    Offers Previous / Next / (optional) Random / Save buttons and a 1-based
    index box. Each item of ``images`` is passed to the display function; the
    default one opens it as an image path.
    """

    def __init__(self,
                 images,
                 name="iterator",
                 show_name=True,
                 show_axis=False,
                 show_random=True,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None):
        """Build all widgets; nothing is displayed until ``start_iteration``.

        Args:
            images: non-empty sequence of items to iterate over.
            name: CSS class added to the output area (used by the screenshot script).
            show_name: print the file name above the image.
            show_axis: keep the matplotlib axes visible.
            show_random: add a "Random" button.
            fig_size: matplotlib figure size.
            buttons_vertical: stack the controls in a column left of the image.
            image_display_function: optional ``f(item, index)`` replacing the
                built-in renderer.

        Raises:
            ValueError: if ``images`` is empty (a subclass of the ``Exception``
                raised previously, so existing handlers keep working).
        """
        if len(images) == 0:
            raise ValueError("No images provided")

        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        self.show_random = show_random
        self.images = images
        self.max_pos = len(self.images) - 1
        self.pos = 0
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        self.previous_button = self.__create_button("Previous",
                                                    (self.pos == 0),
                                                    self.__on_previous_clicked)
        self.next_button = self.__create_button("Next",
                                                (self.pos == self.max_pos),
                                                self.__on_next_clicked)
        self.save_button = self.__create_button("Save", False,
                                                self.__on_save_clicked)
        self.save_function = self.__save_function

        buttons = [self.previous_button, self.next_button]

        if self.show_random:
            self.random_button = self.__create_button("Random", False,
                                                      self.__on_random_clicked)
            buttons.append(self.random_button)

        buttons.append(self.save_button)

        label_total = Label(value='/ {}'.format(len(self.images)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.images))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class(name)  # the screenshot JavaScript selects the output area by this class

        if self.buttons_vertical:
            self.all_widgets = HBox(children=[
                VBox(children=[HBox([self.text_index, label_total])] +
                     buttons), self.out
            ])
        else:
            self.all_widgets = VBox(children=[
                HBox([self.text_index, label_total]),
                HBox(children=buttons), self.out
            ])
        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_button(self, description, disabled, function):
        """Create a button with the given label, disabled state and click handler."""
        button = Button(description=description)
        button.disabled = disabled
        button.on_click(function)
        return button

    def __show_image(self, image_path, index):
        """Default renderer: open ``image_path`` and show it with matplotlib."""
        img = Image.open(image_path)
        if self.show_name:
            print(os.path.basename(image_path))
        plt.figure(figsize=self.fig_size)
        if not self.show_axis:
            plt.axis('off')
        plt.imshow(img)
        plt.show()

    def __screenshot_js(self, image_path):
        """Return the JavaScript that screenshots the output area via html2canvas.

        The download is named after ``image_path`` with its extension replaced
        by ``.png``. ``os.path.splitext`` is used so that names containing
        dots (e.g. ``img.v2.jpg``) are no longer cut at the first dot.
        """
        img_name = os.path.splitext(os.path.basename(image_path))[0]
        j_code = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) {
                    var myImage = canvas.toDataURL();
                    var a = document.createElement("a");
                    a.href = myImage;
                    a.download = "$img_name$.png";
                    a.click();
                    a.remove();
                });
            });
            """
        j_code = j_code.replace('$it_name$', self.name)
        return j_code.replace('$img_name$', img_name)

    def __save_function(self, image_path, index):
        """Default save action: run the screenshot script in a throw-away output."""
        j_code = self.__screenshot_js(image_path)
        tmp_out = Output()
        with tmp_out:
            display(Javascript(j_code))
            tmp_out.clear_output()

    def __on_next_clicked(self, b):
        """Move one item forward."""
        self.pos += 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_save_clicked(self, b):
        """Hand the current item and its index to ``save_function``."""
        self.save_function(self.images[self.pos], self.pos)

    def __perform_action(self, index, max_pos):
        """Refresh button states, redraw item ``index`` and sync the index box."""
        self.next_button.disabled = (index == max_pos)
        self.previous_button.disabled = (index == 0)

        with self.out:
            self.out.clear_output()
        with self.out:
            self.image_display_function(self.images[index], index)

        # keep the index box in sync without re-triggering __selected_index
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = index + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        """Move one item back."""
        self.pos -= 1
        self.__perform_action(self.pos, self.max_pos)

    def __on_random_clicked(self, b):
        """Jump to a uniformly chosen position (may be the current one)."""
        self.pos = random.randint(0, self.max_pos)
        self.__perform_action(self.pos, self.max_pos)

    def __selected_index(self, t):
        """Observer of the index box: jump to the 1-based index that was entered."""
        if t['owner'].value is None or t['name'] != 'value':
            return
        self.pos = t['new'] - 1
        self.__perform_action(self.pos, self.max_pos)

    def start_iteration(self):
        """Display the widgets and show the current item (or report that none exist)."""
        if self.max_pos < self.pos:
            print("No available images")
            return

        display(self.all_widgets)
        self.__perform_action(self.pos, self.max_pos)
Beispiel #3
0
class Annotator:

    def __init__(self,
                 dataset,
                 metrics_and_values,  # properties{ name: ([possible values], just_one_value = True)}
                 output_path=None,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 image_display_function=None,
                 classes_to_annotate=None
                 ):

        self.dataset_orig = dataset
        self.metrics_and_values = metrics_and_values
        self.show_axis = show_axis
        self.name = (self.dataset_orig.dataset_root_param.split('/')[-1]).split('.')[0]  # get_original_file_name
        self.show_name = show_name
        if output_path is None:
            splitted_array = self.dataset_orig.dataset_root_param.split('/')
            n = len(splitted_array)
            self.output_directory = os.path.join(*(splitted_array[0:n - 1]))
        else:
            self.output_directory = output_path

        if classes_to_annotate is None:  # if classes_to_annotate is None, all the classes would be annotated
            self.classes_to_annotate = self.dataset_orig.get_categories_names()  # otherwise, the only the classes in the list

        self.file_path_for_json = os.path.join("/", self.output_directory, self.name + "_ANNOTATED.json")
        print("New dataset with meta_annotations {}".format(self.file_path_for_json))
        self.mapping = None
        self.objects = self.dataset_orig.get_annotations_from_class_list(self.classes_to_annotate)
        self.max_pos = len(self.objects) - 1
        self.current_pos = 0
        self.mapping, self.dataset_annotated = self.__create_results_dict(self.file_path_for_json)

        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        if image_display_function is None:
            self.image_display_function = self.__show_image
        else:
            self.image_display_function = image_display_function

        # create buttons
        self.previous_button = self.__create_button("Previous", (self.current_pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button("Next", (self.current_pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button("Download", False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function
        self.current_image = {}
        buttons = [self.previous_button, self.next_button]
        buttons.append(self.save_button)

        label_total = Label(value='/ {}'.format(len(self.objects)))
        self.text_index = BoundedIntText(value=1, min=1, max=len(self.objects))
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)
        self.out = Output()
        self.out.add_class("my_canvas_class")

        self.checkboxes = {}
        self.radiobuttons = {}
        self.bounded_text = {}

        output_layout = []
        for k_name, v in self.metrics_and_values.items():
            if MetaPropertiesTypes.META_PROP_UNIQUE == v[1]:  # radiobutton
                self.radiobuttons[k_name] = RadioButtons(name=k_name, options=v[0],
                                                         disabled=False,
                                                         indent=False)
            elif MetaPropertiesTypes.META_PROP_COMPOUND == v[1]:  # checkbox
                self.checkboxes[k_name] = [Checkbox(False, description='{}'.format(prop_name),
                                                    indent=False, name=k_name) for prop_name in v[0]]
            elif MetaPropertiesTypes.META_PROP_CONTINUE == v[1]:
                self.bounded_text[k_name] = BoundedFloatText(value=v[0][0], min=v[0][0], max=v[0][1])

        self.check_radio_boxes_layout = {}

        for rb_k, rb_v in self.radiobuttons.items():
            rb_v.layout.width = '180px'
            rb_v.observe(self.__checkbox_changed)
            rb_v.add_class(rb_k)
            html_title = HTML(value="<b>" + rb_k + "</b>")
            self.check_radio_boxes_layout[rb_k] = VBox([rb_v])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[rb_k]]))

        for cb_k, cb_i in self.checkboxes.items():
            for cb in cb_i:
                cb.layout.width = '180px'
                cb.observe(self.__checkbox_changed)
                cb.add_class(cb_k)
            html_title = HTML(value="<b>" + cb_k + "</b>")
            self.check_radio_boxes_layout[cb_k] = VBox(children=[cb for cb in cb_i])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[cb_k]]))

        for bf_k, bf in self.bounded_text.items():
            bf.layout.width = '80px'
            bf.layout.height = '35px'
            bf.observe(self.__checkbox_changed)
            bf.add_class(bf_k)
            html_title = HTML(value="<b>" + bf_k + "</b>")
            self.check_radio_boxes_layout[bf_k] = VBox([bf])
            output_layout.append(VBox([html_title, self.check_radio_boxes_layout[bf_k]]))

        self.all_widgets = VBox(children=
                                [HBox([self.text_index, label_total]),
                                 HBox(buttons),
                                 HBox([self.out,
                                 VBox(output_layout)])])

        ## loading js library to perform html screenshots
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __add_annotation_to_mapping(self, ann):
        cat_name = self.dataset_orig.get_category_name_from_id(ann['category_id'])
        if ann['id'] not in self.mapping['annotated_ids']:  # o nly adds category counter if it was not annotate
            self.mapping['categories_counter'][cat_name] += 1

        self.mapping['annotated_ids'].add(ann['id'])

    def __create_results_dict(self, file_path):
        mapping = {}
        mapping["annotated_ids"] = set()  # key: object_id from dataset, values=[(annotations done)]
        mapping["categories_counter"] = dict.fromkeys([c for c in self.dataset_orig.get_categories_names()], 0)
        self.mapping = mapping

        if not os.path.exists("/" + file_path): #it does exist __ANNOTATED in the output directory
            with open(self.dataset_orig.dataset_root_param, 'r') as input_json_file:
                dataset_annotated = json.load(input_json_file)
                input_json_file.close()
            #take the same metaproperties already in file if it's not empty (like a new file)
            meta_prop = dataset_annotated['meta_properties'] if 'meta_properties' in dataset_annotated.keys() else []

            # adds the new annotations categories to dataset if it doesn't exist
            for k_name, v in self.metrics_and_values.items():
                new_mp_to_append = {
                    "name": k_name,
                    "type": v[1].value,
                    "values": [p for p in v[0]].sort()
                }

                names_prop_in_file = {m_p['name']: m_p for m_i, m_p in enumerate(dataset_annotated[
                                                                                     'meta_properties'])} if 'meta_properties' in dataset_annotated.keys() else None

                if 'meta_properties' not in dataset_annotated.keys():  # it is a new file
                    meta_prop.append(new_mp_to_append)
                    dataset_annotated['meta_properties'] = []

                elif names_prop_in_file is not None and k_name not in names_prop_in_file.keys():
                    # if there is a property with the same in meta_properties, it must be the same structure as the one proposed
                    meta_prop.append(new_mp_to_append)
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                elif names_prop_in_file is not None and k_name in names_prop_in_file.keys() and \
                        names_prop_in_file[k_name] == new_mp_to_append:
                    #we don't append because it's already there
                    self.__update_annotation_counter_and_current_pos(dataset_annotated)

                else:
                    raise NameError("An annotation with the same name {} "
                                    "already exist in dataset {}, and it has different structure. Check properties.".format(
                        k_name, self.dataset_orig.dataset_root_param))

                #if k_name is in name_props_in_file and it's the same structure. No update is done.
            dataset_annotated['meta_properties'] = dataset_annotated['meta_properties'] + meta_prop
        else:
            with open(file_path, 'r') as classification_file:
                dataset_annotated = json.load(classification_file)
                classification_file.close()

            self.__update_annotation_counter_and_current_pos(dataset_annotated)
        return mapping, dataset_annotated

    def __update_annotation_counter_and_current_pos(self, dataset_annotated):
        prop_names = set(self.metrics_and_values.keys())
        last_ann_id = self.objects[0]['id']
        for ann in dataset_annotated['annotations']:
            if prop_names.issubset(
                    set(ann.keys())):  # we consider that it was annotated when all the props are subset of keys
                self.__add_annotation_to_mapping(ann)
                last_ann_id = ann['id']  # to get the last index in self.object so we can update the current pos in the last one anotated
            else:
                break
        self.current_pos = next(i for i, a in enumerate(self.objects) if a['id'] == last_ann_id)

    def __checkbox_changed(self, b):
        """Store a widget value change on the annotation of the current object.

        :param b: change dictionary delivered by the widget; ``b['owner']`` is
            the widget and ``b['name']`` the name of the changed trait.

        Changes whose trait is not ``'value'`` or whose new value is ``None``
        are ignored. The widget's first DOM class is used as the property name
        and its description as the option name. For a compound property the
        annotation holds a dict of options (created with zeros on first use)
        and the changed option is stored as an int; otherwise the raw value is
        stored under the property name.
        """
        widget = b['owner']
        if widget.value is None or b['name'] != 'value':
            return

        class_name = widget.description
        value = widget.value
        annotation_name = widget._dom_classes[0]

        ann_id = self.objects[self.current_pos]['id']
        ann = next((candidate for candidate in self.dataset_annotated['annotations']
                    if candidate['id'] == ann_id), None)
        if ann is None:
            # No record carries the current id: do nothing rather than write
            # the value onto an unrelated annotation.
            return

        if self.metrics_and_values[annotation_name][1] == MetaPropertiesTypes.META_PROP_COMPOUND:
            if annotation_name not in ann:
                ann[annotation_name] = {p: 0 for p in self.metrics_and_values[annotation_name][0]}
            ann[annotation_name][class_name] = int(value)
        else:  # unique value
            ann[annotation_name] = value

    def __create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label, disabled flag and click handler."""
        btn = Button(description=description, disabled=disabled)
        btn.on_click(function)
        return btn

    def __show_image(self, image_record, ann_key):
        """Default display function: plot an image and outline one annotation.

        :param image_record: image dictionary; its ``'file_name'`` is joined
            with ``self.dataset_orig.images_abs_path`` to locate the file.
        :param ann_key: key of the annotation in ``self.dataset_orig.coco_lib.anns``.

        When ``self.show_name`` is set, the file name and the category of the
        current object are printed. The annotation is drawn as its bounding box
        (unfilled) or, for segmentation datasets, as one translucent polygon
        per segmentation entry, labelled with the category name.
        """
        dataset = self.dataset_orig
        image_path = os.path.join(dataset.images_abs_path, image_record['file_name'])
        picture = Image.open(image_path)

        if self.show_name:
            caption = '. Class: {} [class_id={}]'.format(
                dataset.get_category_name_from_id(self.objects[self.current_pos]['category_id']),
                self.objects[self.current_pos]['category_id'])
            print(os.path.basename(image_path) + caption)

        plt.figure(figsize=self.fig_size)
        if not self.show_axis:
            plt.axis('off')
        plt.imshow(picture)

        axes = plt.gca()
        # one rainbow colour per category name
        palette = cm.rainbow(np.linspace(0, 1, len(dataset.get_categories_names())))

        annotation = dataset.coco_lib.anns[ann_key]
        label = dataset.get_category_name_from_id(annotation['category_id'])
        rgb = palette[dataset.get_categories_names().index(label)]

        def outline(points, fill_alpha):
            # solid edge in the class colour, fill opacity chosen by the caller
            axes.add_patch(
                Polygon(points, linestyle='-', facecolor=(rgb[0], rgb[1], rgb[2], fill_alpha),
                        edgecolor=(rgb[0], rgb[1], rgb[2], 1.0), linewidth=2))

        if dataset.is_segmentation:
            polygons = annotation['segmentation']
            for flat in polygons:
                # flat list x0, y0, x1, y1, ... -> list of [x, y] vertices
                vertices = [[float(flat[k]), float(flat[k + 1])] for k in range(0, len(flat), 2)]
                outline(np.array(vertices), 0.25)
            # the label is anchored at the first two values of the first polygon
            text_x = polygons[0][0]
            text_y = polygons[0][1]
        else:
            left, top, width, height = annotation['bbox']
            right = left + width
            bottom = top + height
            corners = np.array([[left, top], [left, bottom], [right, bottom],
                                [right, top]]).reshape((4, 2))
            outline(corners, 0.0)
            text_x, text_y = left, top

        # class name on a translucent background in the class colour
        axes.text(x=text_x, y=text_y, s=label, color='white', fontsize=9, horizontalalignment='left',
                  verticalalignment='top',
                  bbox=dict(facecolor=(rgb[0], rgb[1], rgb[2], 0.5)))
        plt.show()

    def save_state(self):
        """Write the annotated dataset as JSON and register the current annotation.

        The dataset is serialised to ``self.file_path_for_json`` through
        ``SafeWriter``. Afterwards the annotation whose id matches the current
        object is passed to ``__add_annotation_to_mapping``.
        """
        writer = SafeWriter(os.path.join(self.file_path_for_json), "w")
        writer.write(json.dumps(self.dataset_annotated))
        writer.close()

        current_matches = (record for record in self.dataset_annotated['annotations']
                           if record['id'] == self.objects[self.current_pos]['id'])
        self.__add_annotation_to_mapping(next(current_matches))

    def __save_function(self, image_path):
        """Ask the browser to download a screenshot of the output area as a PNG.

        :param image_path: path whose base name, cut at the first ``'.'``, is
            used as the name of the downloaded file.

        The JavaScript snippet relies on ``html2canvas`` being loadable through
        ``require`` and targets the element carrying the ``my_canvas_class``
        CSS class.
        """
        stem = os.path.basename(image_path).split('.')[0]
        script_template = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        script = script_template.replace('$it_name$', "my_canvas_class").replace('$img_name$', stem)

        # Run the script inside a throw-away Output widget and clear it right away.
        scratch = Output()
        with scratch:
            display(Javascript(script))
            scratch.clear_output()

    def __perform_action(self):
        """Refresh buttons, property widgets, image and index box for the current object.

        Each property widget is detached from ``__checkbox_changed`` while its
        value is restored from the stored annotation, so that restoring a value
        is not recorded as a user edit. Properties that are not stored yet fall
        back to ``None`` (radio buttons), ``False`` (checkboxes) or the widget
        minimum (bounded text).

        :raises StopIteration: if no annotation carries the id of the current object.
        """
        self.next_button.disabled = (self.current_pos == self.max_pos)
        self.previous_button.disabled = (self.current_pos == 0)

        current_id = self.objects[self.current_pos]['id']
        current_ann = next(
            a for a in self.dataset_annotated['annotations'] if a['id'] == current_id)

        for m_k, m_v in self.metrics_and_values.items():
            prop_type = m_v[1]
            if prop_type == MetaPropertiesTypes.META_PROP_UNIQUE:  # radiobutton
                radio = self.radiobuttons[m_k]
                radio.unobserve(self.__checkbox_changed)
                radio.value = current_ann.get(m_k)
                radio.observe(self.__checkbox_changed)
            elif prop_type == MetaPropertiesTypes.META_PROP_COMPOUND:  # checkbox
                stored = current_ann.get(m_k, {})
                for cb_v in self.checkboxes[m_k]:
                    cb_v.unobserve(self.__checkbox_changed)
                    # an option missing from the stored dict counts as unchecked
                    cb_v.value = bool(stored.get(cb_v.description, 0))
                    cb_v.observe(self.__checkbox_changed)
            elif prop_type == MetaPropertiesTypes.META_PROP_CONTINUE:  # textbound
                text = self.bounded_text[m_k]
                text.unobserve(self.__checkbox_changed)
                text.value = float(current_ann[m_k]) if m_k in current_ann else text.min
                text.observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            image_record, ann_key = self.__get_image_record()
            self.image_display_function(image_record, ann_key)

        # Update the index box without triggering __selected_index.
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.current_pos + 1
        self.text_index.observe(self.__selected_index)

    def __get_image_record(self):
        """Return ``(image record, annotation id)`` for the current object.

        The image record is looked up in ``self.dataset_orig.coco_lib.imgs``
        using the ``'image_id'`` of ``self.objects[self.current_pos]``; the
        annotation id is that object's ``'id'``.
        """
        current = self.objects[self.current_pos]
        image_id = current['image_id']
        annotation_id = current['id']
        return self.dataset_orig.coco_lib.imgs[image_id], annotation_id

    def __on_previous_clicked(self, b):
        """Click handler: save the state, step one object back and refresh the view."""
        self.save_state()
        self.current_pos = self.current_pos - 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Click handler: save the state, step one object forward and refresh the view."""
        self.save_state()
        self.current_pos = self.current_pos + 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Click handler: save the state and hand the current image path to ``save_function``.

        The path is built from ``self.output_directory``, the ``JPEGImages``
        folder and the ``'file_name'`` of the current image record.
        """
        self.save_state()
        record = self.__get_image_record()[0]
        target = os.path.join(self.output_directory, 'JPEGImages', record['file_name'])
        self.save_function(target)

    def __selected_index(self, t):
        """Jump to the object whose 1-based index was typed in the index box.

        Changes of traits other than ``'value'``, or with an owner value of
        ``None``, are ignored.
        """
        if t['owner'].value is not None and t['name'] == 'value':
            self.current_pos = t['new'] - 1
            self.__perform_action()

    def start_annotation(self):
        """Display the widgets and show the current object.

        Prints a notice and does nothing else when ``self.current_pos`` lies
        beyond ``self.max_pos``.
        """
        if self.current_pos > self.max_pos:
            print("No available images")
        else:
            display(self.all_widgets)
            self.__perform_action()

    def print_statistics(self):
        """Print a table with the annotated-object count of every class.

        Rows come from ``self.mapping["categories_counter"]`` sorted by class
        name; the last row shows the summed count over ``len(self.objects)``.
        """
        counters = self.mapping["categories_counter"]
        rows = sorted(([label, amount] for label, amount in counters.items()),
                      key=lambda row: row[0])
        annotated_total = sum(amount for amount in counters.values())
        rows.append(['Total', '{}/{}'.format(annotated_total, len(self.objects))])
        print(tabulate(rows, headers=['Class name', 'Annotated objects']))
Beispiel #4
0
class DatasetAnnotatorClassification:
    supported_types = [
        TaskType.CLASSIFICATION_BINARY, TaskType.CLASSIFICATION_SINGLE_LABEL,
        TaskType.CLASSIFICATION_MULTI_LABEL
    ]

    def __init__(self,
                 task_type,
                 observations,
                 output_path,
                 name,
                 classes,
                 show_name=True,
                 show_axis=False,
                 fig_size=(10, 10),
                 buttons_vertical=False,
                 custom_display_function=None,
                 is_image=True):
        """Validate the arguments, load or create the results file and build the widgets.

        :param task_type: one of ``supported_types``.
        :param observations: non-empty sequence of observations to classify.
        :param output_path: directory of the ``<name>.json`` results file.
        :param name: file stem of the results file, also added as CSS class to
            the output widget.
        :param classes: class names; at least two, and at most two for the
            binary task.
        :param show_name: stored and used by the default display function to
            print the file name.
        :param show_axis: stored and used by the default display function to
            keep the plot axes.
        :param fig_size: figure size used by the default display function.
        :param buttons_vertical: place the navigation buttons in a column next
            to the output instead of in a row above it.
        :param custom_display_function: callable that receives the observation
            record; mandatory when ``is_image`` is false.
        :param is_image: whether observations are image paths (record key
            ``"path"``) or other values (record key ``"observation"``).
        :raises Exception: when one of the validations above fails.
        """
        if task_type not in self.supported_types:
            raise Exception(labels_str.warn_task_not_supported)

        if len(observations) == 0:
            raise Exception(labels_str.warn_no_images)

        binary_task = task_type.value == TaskType.CLASSIFICATION_BINARY.value
        num_classes = len(classes)
        if num_classes <= 1:
            raise Exception(labels_str.warn_little_classes)
        if num_classes > 2 and binary_task:
            raise Exception(labels_str.warn_binary_only_two)

        if not is_image and custom_display_function is None:
            raise Exception(labels_str.warn_display_function_needed)
        self.is_image = is_image
        self.key = "path" if is_image else "observation"

        self.task_type = task_type
        self.show_axis = show_axis
        self.name = name
        self.show_name = show_name
        self.output_path = output_path
        self.file_path = os.path.join(output_path, name + ".json")
        print(labels_str.info_ds_output + self.file_path)
        self.mapping, self.dataset = self.__create_results_dict(self.file_path, classes)

        # The results file may already hold more classes than were passed in.
        self.classes = list(self.mapping["categories_id"].values())
        if len(self.classes) > 2 and binary_task:
            raise Exception(labels_str.warn_binary_only_two + " ".join(self.classes))

        self.observations = observations
        self.max_pos = len(observations) - 1
        self.pos = 0
        self.fig_size = fig_size
        self.buttons_vertical = buttons_vertical

        self.image_display_function = (self.__show_image
                                       if custom_display_function is None
                                       else custom_display_function)

        # navigation and download buttons
        self.previous_button = self.__create_button(
            labels_str.str_btn_prev, (self.pos == 0), self.__on_previous_clicked)
        self.next_button = self.__create_button(
            labels_str.str_btn_next, (self.pos == self.max_pos), self.__on_next_clicked)
        self.save_button = self.__create_button(
            labels_str.str_btn_download, False, self.__on_save_clicked)
        self.save_function = self.__save_function  # save_function
        buttons = [self.previous_button, self.next_button, self.save_button]

        # "<index> / <total>" position selector
        total = len(self.observations)
        label_total = Label(value='/ {}'.format(total))
        self.text_index = BoundedIntText(value=1, min=1, max=total)
        self.text_index.layout.width = '80px'
        self.text_index.layout.height = '35px'
        self.text_index.observe(self.__selected_index)

        self.out = Output()
        self.out.add_class(name)

        # one checkbox per class for multi-label, a single radio group otherwise
        if self.__is_multilabel():
            self.checkboxes = [
                Checkbox(False, description='{}'.format(class_name), indent=False)
                for class_name in self.classes
            ]
            for checkbox in self.checkboxes:
                checkbox.layout.width = '180px'
                checkbox.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=list(self.checkboxes))
        else:
            self.checkboxes = RadioButtons(options=self.classes,
                                           disabled=False,
                                           indent=False)
            self.checkboxes.layout.width = '180px'
            self.checkboxes.observe(self.__checkbox_changed)
            self.checkboxes_layout = VBox(children=[self.checkboxes])

        output_layout = HBox(children=[self.out, self.checkboxes_layout])
        index_row = HBox([self.text_index, label_total])
        if self.buttons_vertical:
            controls = VBox(children=[index_row] + buttons)
            self.all_widgets = HBox(children=[controls, output_layout])
        else:
            self.all_widgets = VBox(
                children=[index_row, HBox(children=buttons), output_layout])

        # register html2canvas with require.js so screenshots can be taken later
        j_code = """
                require.config({
                    paths: {
                        html2canvas: "https://html2canvas.hertzen.com/dist/html2canvas.min"
                    }
                });
            """
        display(Javascript(j_code))

    def __create_results_dict(self, file_path, cc):
        """Load or create the results dataset and build its lookup tables.

        :param file_path: JSON results file; loaded when it exists.
        :param cc: iterable of class names that must be present as categories.
        :return: ``(mapping, dataset)``. ``dataset`` holds the ``'categories'``
            and ``'observations'`` lists. ``mapping`` holds ``'categories_id'``
            (id -> name), ``'categories_name'`` (name -> id) and
            ``'observations'`` (value stored under ``self.key`` -> position in
            ``dataset['observations']``).

        Names in ``cc`` that are not yet known are appended as new categories.
        New ids continue after the highest existing id, so they cannot collide
        with ids already stored in the file, and a name repeated in ``cc`` is
        added only once.
        """
        mapping = {
            "categories_id": {},
            "categories_name": {},
            "observations": {}
        }

        if os.path.exists(file_path):
            with open(file_path, 'r') as classification_file:
                dataset = json.load(classification_file)
            for index, img in enumerate(dataset['observations']):
                mapping['observations'][img[self.key]] = index
        else:
            dataset = {'categories': [], "observations": []}

        for c in dataset['categories']:
            mapping['categories_id'][c["id"]] = c["name"]
            mapping['categories_name'][c["name"]] = c["id"]

        # Continue after the highest id in use; 1 for a brand-new dataset.
        next_id = max(mapping['categories_id'], default=0) + 1
        for c in cc:
            if c in mapping['categories_name']:
                continue
            mapping['categories_id'][next_id] = c
            mapping['categories_name'][c] = next_id
            dataset["categories"].append({
                "supercategory": c,
                "name": c,
                "id": next_id
            })
            next_id += 1

        return mapping, dataset

    def __checkbox_changed(self, b):
        """Record a checkbox / radio button change on the current observation.

        Changes of traits other than ``'value'``, or with a value of ``None``,
        are ignored. For multi-label tasks the class id (looked up from the
        widget description) is added to or removed from the record's
        ``"categories"`` list; otherwise the id of the selected class name is
        stored under ``"category"``. The state is saved right away when the
        last observation is being shown.
        """
        owner = b['owner']
        if owner.value is None or b['name'] != 'value':
            return

        class_name = owner.description
        value = owner.value
        record = self.dataset["observations"][
            self.mapping["observations"][self.observations[self.pos]]]

        if self.__is_multilabel():
            class_index = self.mapping["categories_name"][class_name]
            selected = record["categories"]
            if value:
                if class_index not in selected:
                    selected.append(class_index)
            elif class_index in selected:
                selected.remove(class_index)
        else:
            record["category"] = self.mapping["categories_name"][value]

        if self.pos == self.max_pos:
            self.save_state()

    def __is_multilabel(self):
        """Return whether the task type value equals the multi-label classification value."""
        multi_label_value = TaskType.CLASSIFICATION_MULTI_LABEL.value
        return multi_label_value == self.task_type.value

    def __create_button(self, description, disabled, function):
        """Return a ``Button`` with the given label, disabled flag and click handler."""
        btn = Button(description=description, disabled=disabled)
        btn.on_click(function)
        return btn

    def __show_image(self, image_record):
        """Default display function: plot the image stored at ``image_record['path']``.

        :param image_record: observation record; its ``'path'`` entry is the
            image file to show.

        A record without ``'path'``, or whose file does not exist, only prints
        a message and returns, instead of failing on the following lookup or
        open. The file name is printed first when ``self.show_name`` is set,
        and the axes are hidden unless ``self.show_axis`` is set.
        """
        if 'path' not in image_record:
            print("missing path")
            return

        image_path = image_record['path']
        if not os.path.exists(image_path):
            print("Image cannot be loaded: " + image_path)
            return

        if self.show_name:
            print(os.path.basename(image_path))

        # The context manager closes the file once the pixels were handed to imshow.
        with Image.open(image_path) as img:
            plt.figure(figsize=self.fig_size)
            if not self.show_axis:
                plt.axis('off')
            plt.imshow(img)
        plt.show()

    def save_state(self):
        """Write ``self.dataset`` to ``self.file_path`` as JSON indented by four spaces.

        The dataset is serialised before the file is opened: opening with
        ``'w'`` truncates the file, so a serialisation error must surface
        first to keep the previously saved state intact.

        :raises TypeError: if the dataset holds a value that JSON cannot encode.
        """
        serialized = json.dumps(self.dataset, indent=4)
        with open(self.file_path, 'w') as output_file:
            output_file.write(serialized)

    def __save_function(self, image_path, index):
        """Ask the browser to download a screenshot of the output area as a PNG.

        :param image_path: path whose base name, cut at the first ``'.'``, is
            used as the name of the downloaded file.
        :param index: not used by this implementation.

        The JavaScript snippet relies on ``html2canvas`` being loadable through
        ``require`` and targets the element carrying ``self.name`` as CSS class.
        """
        stem = os.path.basename(image_path).split('.')[0]
        script_template = """
            require(["html2canvas"], function(html2canvas) {
                var element = $(".p-Widget.jupyter-widgets-output-area.output_wrapper.$it_name$")[0];
                console.log(element);
                 html2canvas(element).then(function (canvas) { 
                    var myImage = canvas.toDataURL(); 
                    var a = document.createElement("a"); 
                    a.href = myImage; 
                    a.download = "$img_name$.png"; 
                    a.click(); 
                    a.remove(); 
                });
            });
            """
        script = script_template.replace('$it_name$', self.name).replace('$img_name$', stem)

        # Run the script inside a throw-away Output widget and clear it right away.
        scratch = Output()
        with scratch:
            display(Javascript(script))
            scratch.clear_output()

    def __perform_action(self):
        """Show the observation at ``self.pos`` and sync every widget with it.

        An observation seen for the first time gets a record appended to
        ``self.dataset["observations"]`` (with ``"file_name"`` for images, a
        running ``"id"`` and, for multi-label tasks, an empty ``"categories"``
        list) and is indexed in ``self.mapping["observations"]``. The class
        widgets are detached from ``__checkbox_changed`` while their values are
        restored, so that restoring is not recorded as a user edit.
        """
        known = self.mapping["observations"]
        current_observation = self.observations[self.pos]

        if current_observation not in known.keys():
            observation = {self.key: current_observation}
            if self.is_image:
                observation["file_name"] = os.path.basename(current_observation)
            observation["id"] = len(known) + 1
            if self.__is_multilabel():
                observation["categories"] = []
            self.dataset["observations"].append(observation)
            known[observation[self.key]] = len(self.dataset["observations"]) - 1

        current_index = known[current_observation]
        self.next_button.disabled = (self.pos == self.max_pos)
        self.previous_button.disabled = (self.pos == 0)

        record = self.dataset["observations"][current_index]
        if self.__is_multilabel():
            for cb in self.checkboxes:
                cb.unobserve(self.__checkbox_changed)
                # records loaded without a "categories" list get an empty one
                categories = record.setdefault("categories", [])
                cb.value = self.mapping["categories_name"][cb.description] in categories
                cb.observe(self.__checkbox_changed)
        else:
            self.checkboxes.unobserve(self.__checkbox_changed)
            category = record.get("category")
            if category:
                # translate the stored id back into its class name; keep the
                # id unchanged when no class carries it
                category = next(
                    (label for label, class_id in self.mapping["categories_name"].items()
                     if class_id == category),
                    category)
            self.checkboxes.value = category
            self.checkboxes.observe(self.__checkbox_changed)

        with self.out:
            self.out.clear_output()
            self.image_display_function(self.dataset["observations"][current_index])

        # Update the index box without triggering __selected_index.
        self.text_index.unobserve(self.__selected_index)
        self.text_index.value = self.pos + 1
        self.text_index.observe(self.__selected_index)

    def __on_previous_clicked(self, b):
        """Click handler: save the state, step one observation back and refresh the view."""
        self.save_state()
        self.pos = self.pos - 1
        self.__perform_action()

    def __on_next_clicked(self, b):
        """Click handler: save the state, step one observation forward and refresh the view."""
        self.save_state()
        self.pos = self.pos + 1
        self.__perform_action()

    def __on_save_clicked(self, b):
        """Click handler: save the state, then call ``save_function`` for the current observation."""
        self.save_state()
        position = self.pos
        self.save_function(self.observations[position], position)

    def __selected_index(self, t):
        """Jump to the observation whose 1-based index was typed in the index box.

        Changes of traits other than ``'value'``, or with an owner value of
        ``None``, are ignored.
        """
        if t['owner'].value is not None and t['name'] == 'value':
            self.pos = t['new'] - 1
            self.__perform_action()

    def start_classification(self):
        """Display the widgets and show the current observation.

        Prints a notice and does nothing else when ``self.pos`` lies beyond
        ``self.max_pos``.
        """
        if self.pos > self.max_pos:
            print("No available observation")
        else:
            display(self.all_widgets)
            self.__perform_action()

    def print_statistics(self):
        """Print a table with the number of annotated observations per class.

        Every category id of ``self.mapping["categories_id"]`` starts at zero.
        Multi-label records count each id of their ``"categories"`` list, other
        records count their ``"category"`` when present. Rows are sorted by
        class name.
        """
        names_by_id = self.mapping["categories_id"]
        counter = {category_id: 0 for category_id in names_by_id}

        multilabel = self.__is_multilabel()
        for record in self.dataset["observations"]:
            if multilabel:
                for category_id in record.get("categories", ()):
                    counter[category_id] += 1
            elif "category" in record:
                counter[record["category"]] += 1

        rows = sorted(([names_by_id[category_id], amount]
                       for category_id, amount in counter.items()),
                      key=lambda row: row[0])

        headers = [labels_str.info_class_name, labels_str.info_ann_images]
        print(tabulate(rows, headers=headers))