示例#1
0
        obj_class if (obj_class.geometry_type == sly.Rectangle)
        else sly.ObjClass(obj_class.name + '_bbox', sly.Rectangle, color=obj_class.color))
    for obj_class in src_project.meta.obj_classes}

# Destination meta = source meta, with each object class mapped through bbox_class_mapping (rectangle classes are
# kept as-is, other classes become a '<name>_bbox' rectangle class) and with the train/val tag metas registered.
dst_obj_classes = sly.ObjClassCollection(bbox_class_mapping.values())
dst_tag_metas = src_project.meta.tag_metas.add_items([tag_meta_train, tag_meta_val])
dst_meta = src_project.meta.clone(obj_classes=dst_obj_classes, tag_metas=dst_tag_metas)
dst_project.set_meta(dst_meta)

# (min, max) bounds for the side length of random crops, as fractions of the image side.
crop_side_fraction = (
    min_crop_side_fraction,
    max_crop_side_fraction,
)

# A train/validation split needs at least one image on each side.
total_images = api.project.get_images_count(src_project_info.id)
if total_images <= 1:
    raise RuntimeError(
        'Need at least 2 images in a project to prepare a training set (at least 1 each for training '
        'and validation).'
    )

# One boolean per image: True -> training fold, False -> validation fold.
is_train_image = sly_dataset.partition_train_val(total_images, validation_fraction)

# Running image counter shared across all datasets; it indexes into is_train_image inside the loop below.
image_idx = 0
for dataset in src_project:
    sly.logger.info('Dataset processing', extra={'dataset_name': dataset.name})
    dst_dataset = dst_project.create_dataset(dataset.name)

    for item_name in dataset:
        item_paths = dataset.get_item_paths(item_name)
        img = sly.image.read(item_paths.img_path)
        ann = sly.Annotation.load_json_file(item_paths.ann_path, src_project.meta)

        # Decide whether this image and its crops should go to a train or validation fold.
        tag = sly.Tag(tag_meta_train) if is_train_image[image_idx] else sly.Tag(tag_meta_val)
        ann = ann.add_tag(tag)
示例#2
0
tag_meta_val = sly.TagMeta(val_tag_name, sly.TagValueType.NONE)
dst_meta = src_meta.add_tag_metas([tag_meta_train, tag_meta_val])

# Resolve the source datasets (the explicit subset when src_dataset_ids is given, otherwise every dataset of the
# source project) and validate the image count BEFORE creating anything on the server. This way a source with too
# few images fails without leaving an empty, orphaned destination project behind in the workspace.
src_dataset_infos = (
    [api.dataset.get_info_by_id(ds_id) for ds_id in src_dataset_ids] if (src_dataset_ids is not None)
    else api.dataset.get_list(src_project.id))
total_images = sum(ds_info.images_count for ds_info in src_dataset_infos)

# A train/validation split needs at least one image on each side.
if total_images <= 1:
    raise RuntimeError('Need at least 2 images in a project to prepare a training set (at least 1 each for training '
                       'and validation).')
# One train/val decision per image, looked up by a running image index across all source datasets.
is_train_image = partition_train_val(total_images, validation_fraction)

# Preconditions hold: now create the destination project.
# Will choose a new name if dst_project_name is already taken.
dst_project = api.project.create(WORKSPACE_ID, dst_project_name, change_name_if_conflict=True)
api.project.update_meta(dst_project.id, dst_meta.to_json())

# Index in is_train_image of the first image of the current batch (the per-batch enumerate starts from it).
batch_start_idx = 0
for src_dataset in src_dataset_infos:
    dst_dataset = api.dataset.create(dst_project.id, src_dataset.name, src_dataset.description)
    images = api.image.get_list(src_dataset.id)
    ds_progress = sly.Progress(
        'Tagging dataset: {!r}/{!r}'.format(src_project.name, src_dataset.name), total_cnt=len(images))
    for batch in sly.batched(images):
        image_ids = [image_info.id for image_info in batch]
        image_names = [image_info.name for image_info in batch]

        ann_infos = api.annotation.download_batch(src_dataset.id, image_ids)
        src_anns = [sly.Annotation.from_json(ann_info.annotation, dst_meta) for ann_info in ann_infos]
        anns_tagged = [ann.add_tag(sly.Tag(tag_meta_train) if is_train_image[image_idx] else sly.Tag(tag_meta_val))
                       for image_idx, ann in enumerate(src_anns, start=batch_start_idx)]