# NOTE(review): this fragment was recovered from a whitespace-mangled source
# (all statements collapsed onto one physical line). The opening of the dict
# comprehension below was truncated in the original; the
# `bbox_class_mapping = {obj_class.name: (` prefix is inferred from the later
# `bbox_class_mapping.values()` usage — TODO confirm against the original file.

# Map every source class to a rectangle-geometry class: classes that are
# already rectangles are kept as-is; any other geometry gets a new
# '<name>_bbox' Rectangle class with the same color.
bbox_class_mapping = {
    obj_class.name: (
        obj_class
        if (obj_class.geometry_type == sly.Rectangle)
        else sly.ObjClass(obj_class.name + '_bbox', sly.Rectangle, color=obj_class.color))
    for obj_class in src_project.meta.obj_classes}

# Destination meta: rectangle-only classes plus the train/val tag metas
# (tag_meta_train / tag_meta_val are defined before this fragment).
dst_meta = src_project.meta.clone(
    obj_classes=sly.ObjClassCollection(bbox_class_mapping.values()),
    tag_metas=src_project.meta.tag_metas.add_items([tag_meta_train, tag_meta_val]))
dst_project.set_meta(dst_meta)

# Allowed (min, max) crop side as a fraction of the image side.
crop_side_fraction = (min_crop_side_fraction, max_crop_side_fraction)

# A train/val split needs at least one image on each side.
total_images = api.project.get_images_count(src_project_info.id)
if total_images <= 1:
    raise RuntimeError('Need at least 2 images in a project to prepare a training set (at least 1 each for training '
                       'and validation).')
# Per-image booleans: True -> training fold, False -> validation fold.
is_train_image = sly_dataset.partition_train_val(total_images, validation_fraction)

# Iterate over datasets and items.
image_idx = 0
for dataset in src_project:
    sly.logger.info('Dataset processing', extra={'dataset_name': dataset.name})
    dst_dataset = dst_project.create_dataset(dataset.name)
    for item_name in dataset:
        item_paths = dataset.get_item_paths(item_name)
        img = sly.image.read(item_paths.img_path)
        ann = sly.Annotation.load_json_file(item_paths.ann_path, src_project.meta)
        # Decide whether this image and its crops should go to a train or validation fold.
        tag = sly.Tag(tag_meta_train) if is_train_image[image_idx] else sly.Tag(tag_meta_val)
        ann = ann.add_tag(tag)
        # NOTE(review): the loop body continues beyond this fragment
        # (presumably cropping/saving and an `image_idx += 1`) — truncated here.
# NOTE(review): this fragment was recovered from a whitespace-mangled source
# (all statements collapsed onto one physical line) and reformatted. It begins
# after `tag_meta_train` is defined and ends before the tagged annotations are
# uploaded / `batch_start_idx` is advanced — both are outside this fragment.

# Valueless tag used to mark validation images; train counterpart is tag_meta_train.
tag_meta_val = sly.TagMeta(val_tag_name, sly.TagValueType.NONE)
dst_meta = src_meta.add_tag_metas([tag_meta_train, tag_meta_val])

# Will choose a new name if dst_project_name is already taken.
dst_project = api.project.create(WORKSPACE_ID, dst_project_name, change_name_if_conflict=True)
api.project.update_meta(dst_project.id, dst_meta.to_json())

# Either the explicitly requested datasets, or every dataset in the project.
src_dataset_infos = (
    [api.dataset.get_info_by_id(ds_id) for ds_id in src_dataset_ids]
    if (src_dataset_ids is not None)
    else api.dataset.get_list(src_project.id))

# A train/val split needs at least one image on each side.
total_images = sum(ds_info.images_count for ds_info in src_dataset_infos)
if total_images <= 1:
    raise RuntimeError('Need at least 2 images in a project to prepare a training set (at least 1 each for training '
                       'and validation).')
# Per-image booleans: True -> training fold, False -> validation fold.
is_train_image = partition_train_val(total_images, validation_fraction)

# Global image index of the first element of the current batch; lets the
# per-batch enumerate() line up with is_train_image across datasets.
batch_start_idx = 0
for src_dataset in src_dataset_infos:
    dst_dataset = api.dataset.create(dst_project.id, src_dataset.name, src_dataset.description)
    images = api.image.get_list(src_dataset.id)
    ds_progress = sly.Progress(
        'Tagging dataset: {!r}/{!r}'.format(src_project.name, src_dataset.name),
        total_cnt=len(images))
    for batch in sly.batched(images):
        image_ids = [image_info.id for image_info in batch]
        image_names = [image_info.name for image_info in batch]
        # Download the batch's annotations and tag each one train/val according
        # to its global position in the partition.
        ann_infos = api.annotation.download_batch(src_dataset.id, image_ids)
        src_anns = [sly.Annotation.from_json(ann_info.annotation, dst_meta) for ann_info in ann_infos]
        anns_tagged = [
            ann.add_tag(sly.Tag(tag_meta_train) if is_train_image[image_idx] else sly.Tag(tag_meta_val))
            for image_idx, ann in enumerate(src_anns, start=batch_start_idx)]
        # NOTE(review): continuation truncated — presumably uploads anns_tagged,
        # advances batch_start_idx by len(batch), and updates ds_progress.