Example #1
def set_ann(self, ann):
    '''
    Validate the annotation type and save it as JSON at the dataset's annotation path.
    '''
    if type(ann) is not self.annotation_class:
        raise TypeError(
            "Type of 'ann' must be {}, not {}".format(
                self.annotation_class.__name__, type(ann)))
    dst_ann_path = self.get_ann_path()
    dump_json_file(ann.to_json(), dst_ann_path)
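Every example on this page funnels data through dump_json_file. A minimal, self-contained sketch of the call itself, assuming it is imported from the Supervisely SDK's io.json module (the output path is a placeholder):

from supervisely_lib.io.json import dump_json_file

# Serialize a plain dict to disk as pretty-printed JSON.
data = {"classes": [], "tags": []}
dump_json_file(data, "/tmp/meta.json", indent=4)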
Example #2
def _save_model_snapshot(self, is_best, opt_data):
    # Write the model config and weights into a fresh checkpoint directory,
    # then report the snapshot and its size to the checkpoints saver.
    out_dir = self.checkpoints_saver.get_dir_to_write()
    dump_json_file(self.out_config,
                   os.path.join(out_dir, TaskPaths.MODEL_CONFIG_NAME))
    self._dump_model_weights(out_dir)
    size_bytes = sly_fs.get_directory_size(out_dir)
    self.checkpoints_saver.saved(is_best, size_bytes, opt_data)
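The snapshot size reported above comes from the SDK's filesystem helper. A small sketch of that one call, assuming get_directory_size is importable from supervisely_lib.io.fs (the directory is a stand-in for the checkpoint folder):

from supervisely_lib.io.fs import get_directory_size

size_bytes = get_directory_size("/tmp")  # stand-in for the checkpoint directory
print(f"snapshot size: {size_bytes} bytes")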
Example #3
def dump_json(self, path):
    '''
    Save the current object's data in JSON format at the given path.
    :param path: str
    '''
    simple_dict = self.to_dict()
    dump_json_file(simple_dict, path, indent=4)
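The same pattern works for any object exposing a to_dict() method. A hedged sketch (ExampleConfig is hypothetical, not part of the SDK):

from supervisely_lib.io.json import dump_json_file

class ExampleConfig:
    # Hypothetical class; only to_dict() matters for the pattern above.
    def to_dict(self):
        return {"lr": 0.001, "epochs": 10}

    def dump_json(self, path):
        dump_json_file(self.to_dict(), path, indent=4)

ExampleConfig().dump_json("/tmp/config.json")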
Example #4
def set_ann(self, item_name: str, ann):
    '''
    Save the given video annotation for the given video into the appropriate folder.
    :param item_name: str
    :param ann: VideoAnnotation object; raises TypeError for any other type
    '''
    if type(ann) is not self.annotation_class:
        raise TypeError(
            "Type of 'ann' must be {}, not {}".format(
                self.annotation_class.__name__, type(ann)))
    dst_ann_path = self.get_ann_path(item_name)
    dump_json_file(ann.to_json(), dst_ann_path)
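A standalone sketch of the type guard's behavior; the Demo* classes are hypothetical stand-ins so the check can run without the SDK:

class DemoAnnotation:
    def to_json(self):
        return {"frames": []}

class DemoDataset:
    annotation_class = DemoAnnotation

    def set_ann(self, ann):
        if type(ann) is not self.annotation_class:
            raise TypeError("Type of 'ann' must be {}, not {}".format(
                self.annotation_class.__name__, type(ann)))

try:
    DemoDataset().set_ann({"frames": []})  # plain dict, not DemoAnnotation -> TypeError
except TypeError as e:
    print(e)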
Example #5
def download_project_objects(api: sly.Api, task_id, context, state,
                             app_logger):
    try:
        if not dir_exists(g.project_dir):
            mkdir(g.project_dir)
            project_meta_path = os.path.join(g.project_dir, "meta.json")
            g.project_meta = convert_object_tags(g.project_meta)
            project_meta_json = g.project_meta.to_json()
            dump_json_file(project_meta_json, project_meta_path)
            datasets = api.dataset.get_list(g.project_id)
            for dataset in datasets:
                ds_dir = os.path.join(g.project_dir, dataset.name)
                img_dir = os.path.join(ds_dir, "img")
                ann_dir = os.path.join(ds_dir, "ann")

                mkdir(ds_dir)
                mkdir(img_dir)
                mkdir(ann_dir)
                images_infos = api.image.get_list(dataset.id)
                download_progress = get_progress_cb(
                    progress_index, "Download project",
                    g.project_info.items_count * 2)
                for batch in sly.batched(images_infos):
                    image_ids = [image_info.id for image_info in batch]
                    image_names = [image_info.name for image_info in batch]
                    ann_infos = api.annotation.download_batch(
                        dataset.id, image_ids, progress_cb=download_progress)

                    image_nps = api.image.download_nps(
                        dataset.id, image_ids, progress_cb=download_progress)
                    anns = [
                        sly.Annotation.from_json(ann_info.annotation,
                                                 g.project_meta)
                        for ann_info in ann_infos
                    ]
                    selected_classes = get_selected_classes_from_ui(
                        state["classesSelected"])
                    crops = crop_and_resize_objects(image_nps, anns, state,
                                                    selected_classes,
                                                    image_names)
                    crop_nps, crop_anns, crop_names = unpack_crops(
                        crops, image_names)
                    crop_anns = copy_tags(crop_anns)
                    write_images(crop_nps, crop_names, img_dir)
                    dump_anns(crop_anns, crop_names, ann_dir)

            reset_progress(progress_index)

        global project_fs
        project_fs = sly.Project(g.project_dir, sly.OpenMode.READ)
        g.images_infos = create_img_infos(project_fs)
    except Exception:
        reset_progress(progress_index)
        raise  # re-raise, preserving the original traceback

    items_count = g.project_stats["objects"]["total"]["objectsInDataset"]
    train_percent = 80
    train_count = int(items_count / 100 * train_percent)
    random_split = {
        "count": {
            "total": items_count,
            "train": train_count,
            "val": items_count - train_count
        },
        "percent": {
            "total": 100,
            "train": train_percent,
            "val": 100 - train_percent
        },
        "shareImagesBetweenSplits": False,
        "sliderDisabled": False,
    }

    fields = [
        {
            "field": "data.done1",
            "payload": True
        },
        {
            "field": "state.collapsed2",
            "payload": False
        },
        {
            "field": "state.disabled2",
            "payload": False
        },
        {
            "field": "state.activeStep",
            "payload": 2
        },
        {
            "field": "data.totalImagesCount",
            "payload": items_count
        },
        {
            "field": "state.randomSplit",
            "payload": random_split
        },
    ]
    g.api.app.set_fields(g.task_id, fields)
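The 80/20 split above relies on integer truncation, so the val split absorbs any remainder. A small worked example of the arithmetic:

# Worked example of the split arithmetic from download_project_objects.
items_count = 157
train_percent = 80
train_count = int(items_count / 100 * train_percent)  # 125
val_count = items_count - train_count                 # 32
assert train_count + val_count == items_count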
Example #6
def create_img_infos(project_fs):
    tag_id_map = {
        tag["name"]: tag["id"]
        for tag in project_fs.meta.tag_metas.to_json()
    }
    images_infos = []
    for dataset_fs in project_fs:
        img_info_dir = os.path.join(dataset_fs.directory, "img_info")
        mkdir(img_info_dir)
        for idx, item_name in enumerate(os.listdir(dataset_fs.item_dir)):
            item_ext = get_file_ext(item_name).lstrip(".")
            item_path = os.path.join(dataset_fs.item_dir, item_name)
            item = sly.image.read(item_path)
            h, w = item.shape[:2]
            item_size = os.path.getsize(item_path)
            item_stat = os.stat(item_path)  # stat once, reuse for both timestamps
            created_at = datetime.fromtimestamp(
                item_stat.st_ctime,
                tz=timezone.utc).strftime("%d-%m-%Y %H:%M:%S")
            modified_at = datetime.fromtimestamp(
                item_stat.st_mtime,
                tz=timezone.utc).strftime("%d-%m-%Y %H:%M:%S")

            item_ann_path = os.path.join(dataset_fs.ann_dir,
                                         f"{item_name}.json")
            ann_json = load_json_file(item_ann_path)
            ann = sly.Annotation.from_json(ann_json, project_fs.meta)
            tags = ann.img_tags
            tags_json = tags.to_json()
            labels_count = len(ann.labels)

            tags_img_info = []
            for tag in tags_json:
                tag_info = {
                    "entityId": None,
                    "tagId": tag_id_map[tag["name"]],
                    "id": None,
                    "labelerLogin": tag["labelerLogin"],
                    "createdAt": tag["createdAt"],
                    "updatedAt": tag["updatedAt"],
                    "name": tag["name"]
                }
                tags_img_info.append(tag_info)

            item_img_info = {
                "id": idx,
                "name": item_name,
                "link": "",
                "hash": "",
                "mime": f"image/{item_ext}",
                "ext": item_ext,
                "size": item_size,
                "width": w,
                "height": h,
                "labels_count": labels_count,
                "dataset_id": dataset_fs.name,
                "created_at": created_at,
                "updated_at": modified_at,
                "meta": {},
                "path_original": "",
                "full_storage_url": "",
                "tags": tags_img_info
            }
            save_path = os.path.join(img_info_dir, f"{item_name}.json")
            dump_json_file(item_img_info, save_path)
            images_infos.append(item_img_info)
    return images_infos
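The created_at/updated_at strings above format st_ctime/st_mtime in UTC. A runnable sketch of just that step (the file path is a placeholder):

import os
from datetime import datetime, timezone

path = "/tmp/example.txt"
open(path, "w").close()  # make sure the file exists for the demo
stat = os.stat(path)
created_at = datetime.fromtimestamp(stat.st_ctime, tz=timezone.utc).strftime("%d-%m-%Y %H:%M:%S")
modified_at = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).strftime("%d-%m-%Y %H:%M:%S")
print(created_at, modified_at)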
Example #7
def dump_anns(crop_anns, crop_names, ann_dir):
    for crop_ann, crop_name in zip(crop_anns, crop_names):
        ann_path = f"{os.path.join(ann_dir, crop_name)}.json"
        ann_json = crop_ann.to_json()
        dump_json_file(ann_json, ann_path)
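Note that the .json suffix is appended after the full item name, yielding names like car_001.png.json; this is the same double-extension convention Example #9 has to repair. A quick check (paths are placeholders):

import os

ann_dir = "/tmp/ann"
crop_name = "car_001.png"
ann_path = f"{os.path.join(ann_dir, crop_name)}.json"
print(ann_path)  # /tmp/ann/car_001.png.json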
Example #8
def download_pointcloud_episode_project(api,
                                        project_id,
                                        dest_dir,
                                        dataset_ids=None,
                                        download_pcd=True,
                                        download_realated_images=True,
                                        download_annotations=True,
                                        log_progress=False,
                                        batch_size=10):
    key_id_map = KeyIdMap()
    project_fs = PointcloudEpisodeProject(dest_dir, OpenMode.CREATE)
    meta = ProjectMeta.from_json(api.project.get_meta(project_id))
    project_fs.set_meta(meta)

    datasets_infos = []
    if dataset_ids is not None:
        for ds_id in dataset_ids:
            datasets_infos.append(api.dataset.get_info_by_id(ds_id))
    else:
        datasets_infos = api.dataset.get_list(project_id)

    for dataset in datasets_infos:
        dataset_fs = project_fs.create_dataset(dataset.name)
        pointclouds = api.pointcloud_episode.get_list(dataset.id)

        if download_annotations:
            # Download annotation to project_path/dataset_path/annotation.json
            ann_json = api.pointcloud_episode.annotation.download(dataset.id)
            annotation = dataset_fs.annotation_class.from_json(
                ann_json, meta, key_id_map)
            dataset_fs.set_ann(annotation)

            # frames --> pointcloud mapping to project_path/dataset_path/frame_pointcloud_map.json
            frame_name_map = api.pointcloud_episode.get_frame_name_map(
                dataset.id)
            frame_pointcloud_map_path = dataset_fs.get_frame_pointcloud_map_path()
            dump_json_file(frame_name_map, frame_pointcloud_map_path)

        # Download data
        if log_progress:
            ds_progress = Progress('Downloading dataset: {!r}'.format(
                dataset.name),
                                   total_cnt=len(pointclouds))

        for batch in batched(pointclouds, batch_size=batch_size):
            pointcloud_ids = [pointcloud_info.id for pointcloud_info in batch]
            pointcloud_names = [
                pointcloud_info.name for pointcloud_info in batch
            ]

            for pointcloud_id, pointcloud_name in zip(pointcloud_ids,
                                                      pointcloud_names):
                pointcloud_file_path = dataset_fs.generate_item_path(
                    pointcloud_name)
                if download_pcd:
                    api.pointcloud_episode.download_path(
                        pointcloud_id, pointcloud_file_path)
                else:
                    touch(pointcloud_file_path)

                if download_realated_images:
                    related_images_path = dataset_fs.get_related_images_path(
                        pointcloud_name)
                    related_images = api.pointcloud_episode.get_list_related_images(
                        pointcloud_id)
                    for rimage_info in related_images:
                        name = rimage_info[ApiField.NAME]
                        rimage_id = rimage_info[ApiField.ID]

                        path_img = os.path.join(related_images_path, name)
                        path_json = os.path.join(related_images_path,
                                                 name + ".json")

                        api.pointcloud_episode.download_related_image(
                            rimage_id, path_img)
                        dump_json_file(rimage_info, path_json)

                dataset_fs.add_item_file(pointcloud_name,
                                         pointcloud_file_path,
                                         _validate_item=False)
            if log_progress:
                ds_progress.iters_done_report(len(batch))

    project_fs.set_key_id_map(key_id_map)
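A usage sketch for this function, assuming a Supervisely server reachable through the standard environment variables; the project id and destination directory are placeholders:

import supervisely_lib as sly

api = sly.Api.from_env()
download_pointcloud_episode_project(
    api,
    project_id=123,              # placeholder
    dest_dir="/data/episodes",   # placeholder
    download_pcd=True,
    download_annotations=True,
    log_progress=True,
)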
Example #9
def download_pointcloud_project(api,
                                project_id,
                                dest_dir,
                                dataset_ids=None,
                                download_items=True,
                                log_progress=False):
    LOG_BATCH_SIZE = 1

    key_id_map = KeyIdMap()

    project_fs = PointcloudProject(dest_dir, OpenMode.CREATE)

    meta = ProjectMeta.from_json(api.project.get_meta(project_id))
    project_fs.set_meta(meta)

    datasets_infos = []
    if dataset_ids is not None:
        for ds_id in dataset_ids:
            datasets_infos.append(api.dataset.get_info_by_id(ds_id))
    else:
        datasets_infos = api.dataset.get_list(project_id)

    for dataset in datasets_infos:
        dataset_fs = project_fs.create_dataset(dataset.name)
        pointclouds = api.pointcloud.get_list(dataset.id)

        ds_progress = None
        if log_progress:
            ds_progress = Progress('Downloading dataset: {!r}'.format(
                dataset.name),
                                   total_cnt=len(pointclouds))
        for batch in batched(pointclouds, batch_size=LOG_BATCH_SIZE):
            pointcloud_ids = [pointcloud_info.id for pointcloud_info in batch]
            pointcloud_names = [
                pointcloud_info.name for pointcloud_info in batch
            ]

            ann_jsons = api.pointcloud.annotation.download_bulk(
                dataset.id, pointcloud_ids)

            for pointcloud_id, pointcloud_name, ann_json in zip(
                    pointcloud_ids, pointcloud_names, ann_jsons):
                if pointcloud_name != ann_json[ApiField.NAME]:
                    raise RuntimeError(
                        "Error in api.pointcloud.annotation.download_bulk: broken order"
                    )

                pointcloud_file_path = dataset_fs.generate_item_path(
                    pointcloud_name)
                if download_items:
                    api.pointcloud.download_path(pointcloud_id,
                                                 pointcloud_file_path)

                    related_images_path = dataset_fs.get_related_images_path(
                        pointcloud_name)
                    related_images = api.pointcloud.get_list_related_images(
                        pointcloud_id)
                    for rimage_info in related_images:
                        name = rimage_info[ApiField.NAME]

                        if not has_valid_ext(name):
                            new_name = get_file_name(
                                name)  # to fix cases like .png.json
                            if has_valid_ext(new_name):
                                name = new_name
                                rimage_info[ApiField.NAME] = name
                            else:
                                raise RuntimeError(
                                    'Something is wrong with the photo context '
                                    'filenames. Please contact support.')

                        rimage_id = rimage_info[ApiField.ID]

                        path_img = os.path.join(related_images_path, name)
                        path_json = os.path.join(related_images_path,
                                                 name + ".json")

                        api.pointcloud.download_related_image(
                            rimage_id, path_img)
                        dump_json_file(rimage_info, path_json)

                else:
                    touch(pointcloud_file_path)

                dataset_fs.add_item_file(pointcloud_name,
                                         pointcloud_file_path,
                                         ann=PointcloudAnnotation.from_json(
                                             ann_json, project_fs.meta,
                                             key_id_map),
                                         _validate_item=False)

            if ds_progress is not None:
                ds_progress.iters_done_report(len(batch))

    project_fs.set_key_id_map(key_id_map)
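The analogous call for a single-frame point cloud project, under the same assumptions (placeholder id and path):

import supervisely_lib as sly

api = sly.Api.from_env()
download_pointcloud_project(
    api,
    project_id=456,                 # placeholder
    dest_dir="/data/pointclouds",   # placeholder
    download_items=True,
    log_progress=True,
)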