예제 #1
0
def import_prediction(pred):
    """Convert a prediction object from the ``ac`` module into ``dm`` annotations.

    Args:
        pred: a prediction instance; its concrete ``ac`` class selects the
            conversion.

    Returns:
        A tuple of ``dm`` annotations:
        - ``ClassificationPrediction``: one ``dm.Label`` per class index, with
          the softmax of ``pred.scores`` stored in the ``'score'`` attribute;
        - ``ArgMaxClassificationPrediction`` / ``CharacterRecognitionPrediction``:
          a single ``dm.Label`` for ``pred.label``;
        - ``DetectionPrediction`` / ``ActionDetectionPrediction``: one
          ``dm.Bbox`` per detection, with its score in ``'score'``;
        - ``DepthEstimationPrediction`` / ``ImageInpaintingPrediction``: a
          single ``dm.Mask``.

    Raises:
        NotImplementedError: if ``pred`` is of any other type.
    """
    if isinstance(pred, ac.ClassificationPrediction):
        scores = softmax(pred.scores)
        # Materialize eagerly so every branch returns a reusable tuple.
        return tuple(dm.Label(label_id, attributes={'score': float(score)})
                     for label_id, score in enumerate(scores))
    elif isinstance(pred, (ac.ArgMaxClassificationPrediction,
                           ac.CharacterRecognitionPrediction)):
        return (dm.Label(int(pred.label)), )
    elif isinstance(pred,
                    (ac.DetectionPrediction, ac.ActionDetectionPrediction)):
        # Detections are given as corner coordinates (x_min, y_min, x_max,
        # y_max); Bbox is built from the top-left corner plus width/height.
        return tuple(dm.Bbox(x0,
                             y0,
                             x1 - x0,
                             y1 - y0,
                             int(label),
                             attributes={'score': float(score)})
                     for label, score, x0, y0, x1, y1 in zip(
                         pred.labels, pred.scores, pred.x_mins, pred.y_mins,
                         pred.x_maxs, pred.y_maxs))
    elif isinstance(pred, ac.DepthEstimationPrediction):
        return (dm.Mask(pred.depth_map), )  # 2d floating point mask
    elif isinstance(pred, ac.ImageInpaintingPrediction):
        return (dm.Mask(pred.value), )  # an image
    # TODO: not yet supported: HitRatioPrediction,
    # MultiLabelRecognitionPrediction, MachineTranslationPrediction,
    # QuestionAnsweringPrediction, PoseEstimation3dPrediction,
    # PoseEstimationPrediction, RegressionPrediction.
    else:
        raise NotImplementedError("Can't convert %s" % type(pred))
예제 #2
0
    def _read_cvat_anno(self, cvat_frame_anno, task_data):
        """Convert one CVAT frame's annotations into Datumaro annotations.

        Args:
            cvat_frame_anno: frame annotations; its ``tags`` and
                ``labeled_shapes`` collections are read.
            task_data: task data whose ``meta['task']['labels']`` supplies the
                per-label attribute specifications.

        Returns:
            A list of Datumaro annotations: a ``Label`` for every tag, followed
            by a ``Points``/``PolyLine``/``Polygon``/``Bbox`` for every shape.
            Cuboid shapes are skipped.

        Raises:
            Exception: if an attribute value cannot be converted to its
                declared input type, or a shape has an unknown type.
        """
        item_anno = []

        categories = self.categories()
        label_cat = categories[datumaro.AnnotationType.label]

        def map_label(name):
            # ``find`` returns a pair; its first element is used as the
            # Datumaro label id (presumably the label index — verify).
            return label_cat.find(name)[0]

        # Attribute specs keyed by label name. Entries of the labels list are
        # unpacked as pairs and only the second element (the spec) is used.
        label_attrs = {
            label['name']: label['attributes']
            for _, label in task_data.meta['task']['labels']
        }

        def convert_attrs(label, cvat_attrs):
            """Build the Datumaro attribute dict for an object of ``label``.

            Every attribute declared for the label is emitted; values absent
            from ``cvat_attrs`` fall back to the spec's ``default_value``.
            NUMBER attributes become floats, CHECKBOX attributes become bools
            (true only for the string 'true', case-insensitively).
            """
            cvat_attrs = {a.name: a.value for a in cvat_attrs}
            dm_attr = {}
            for _, a_desc in label_attrs[label]:
                a_name = a_desc['name']
                a_value = cvat_attrs.get(a_name, a_desc['default_value'])
                try:
                    if a_desc['input_type'] == AttributeType.NUMBER:
                        a_value = float(a_value)
                    elif a_desc['input_type'] == AttributeType.CHECKBOX:
                        a_value = (a_value.lower() == 'true')
                except Exception as e:
                    # Chain the cause so the original traceback is kept.
                    raise Exception(
                        "Failed to convert attribute '%s'='%s': %s" %
                        (a_name, a_value, e)) from e
                dm_attr[a_name] = a_value
            return dm_attr

        for tag_obj in cvat_frame_anno.tags:
            anno_group = tag_obj.group
            anno_label = map_label(tag_obj.label)
            anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)

            anno = datumaro.Label(label=anno_label,
                                  attributes=anno_attr,
                                  group=anno_group)
            item_anno.append(anno)

        for shape_obj in cvat_frame_anno.labeled_shapes:
            anno_group = shape_obj.group
            anno_label = map_label(shape_obj.label)
            anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
            anno_attr['occluded'] = shape_obj.occluded

            # Shapes exposing ``track_id`` (presumably track-derived shapes)
            # also carry track id and keyframe flag as attributes.
            if hasattr(shape_obj, 'track_id'):
                anno_attr['track_id'] = shape_obj.track_id
                anno_attr['keyframe'] = shape_obj.keyframe

            anno_points = shape_obj.points
            if shape_obj.type == ShapeType.CUBOID:
                continue  # Datumaro does not support cuboids

            # Keyword arguments shared by every shape constructor below.
            common = dict(label=anno_label,
                          attributes=anno_attr,
                          group=anno_group,
                          z_order=shape_obj.z_order)

            if shape_obj.type == ShapeType.POINTS:
                anno = datumaro.Points(anno_points, **common)
            elif shape_obj.type == ShapeType.POLYLINE:
                anno = datumaro.PolyLine(anno_points, **common)
            elif shape_obj.type == ShapeType.POLYGON:
                anno = datumaro.Polygon(anno_points, **common)
            elif shape_obj.type == ShapeType.RECTANGLE:
                # Rectangle points are two corners; Bbox takes the first
                # corner plus the width/height spanned to the second.
                x0, y0, x1, y1 = anno_points
                anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0, **common)
            else:
                raise Exception("Unknown shape type '%s'" % shape_obj.type)

            item_anno.append(anno)

        return item_anno