Example No. 1
    def run(self,
            template_path=None,
            data=None,
            state=None,
            initial_events=None):
        if template_path is None:
            # read config
            config_path = os.path.join(self.repo_dir,
                                       os.environ.get("CONFIG_DIR", ""),
                                       'config.json')
            if file_exists(config_path):
                # config.json exists, so we are not in debug mode
                config = load_json_file(config_path)
                template_path = config.get('gui_template', None)
                if template_path is None:
                    self.logger.info(
                        "there is no gui_template field in config.json")
                else:
                    template_path = os.path.join(self.repo_dir, template_path)

            if template_path is None:
                template_path = os.path.join(os.path.dirname(sys.argv[0]),
                                             'gui.html')

        if not file_exists(template_path):
            self.logger.info("App will be running without GUI",
                             extra={"app_url": self.app_url})
            template = ""
        else:
            with open(template_path, 'r') as file:
                template = file.read()

        self.public_api.app.initialize(self.task_id, template, data, state)
        self.logger.info("Application session is initialized",
                         extra={"app_url": self.app_url})

        try:
            self.loop.create_task(self.publish(initial_events),
                                  name="Publisher")
            self.loop.create_task(self.consume(), name="Consumer")
            self.loop.run_forever()
        finally:
            self.loop.close()
            self.logger.info("Successfully shutdown the APP service.")

        if self._error is not None:
            raise self._error
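For reference, a minimal stdlib-only sketch of the same template-resolution order used above (a hypothetical helper, not part of the source class): config.json is consulted first for a 'gui_template' entry, and gui.html next to the entrypoint is the fallback.

import json
import os
import sys

def resolve_template_path(repo_dir, config_dir=""):
    # Hypothetical helper mirroring the lookup order in run():
    # 1) 'gui_template' from config.json, 2) gui.html next to the entrypoint.
    config_path = os.path.join(repo_dir, config_dir, 'config.json')
    if os.path.isfile(config_path):
        with open(config_path) as f:
            template = json.load(f).get('gui_template')
        if template is not None:
            return os.path.join(repo_dir, template)
    return os.path.join(os.path.dirname(sys.argv[0]), 'gui.html')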
Example No. 2
    def get_related_images(self, item_name):
        results = []
        path = self.get_related_images_path(item_name)
        if dir_exists(path):
            files = list_files(path, SUPPORTED_IMG_EXTS)
            for file in files:
                img_meta_path = os.path.join(path, get_file_name_with_ext(file) + ".json")
                img_meta = {}
                if file_exists(img_meta_path):
                    img_meta = load_json_file(img_meta_path)
                results.append((file, img_meta))
        return results
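A minimal sketch of the same sidecar-metadata pattern using only the standard library (hypothetical helper; the SDK utilities dir_exists, list_files, etc. are not required):

import json
import os

IMG_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')  # assumed subset of SUPPORTED_IMG_EXTS

def pair_images_with_meta(path):
    # Pair each image in `path` with its optional "<image name>.json" sidecar.
    results = []
    if not os.path.isdir(path):
        return results
    for name in sorted(os.listdir(path)):
        if not name.lower().endswith(IMG_EXTS):
            continue
        img_path = os.path.join(path, name)
        meta_path = img_path + ".json"
        meta = {}
        if os.path.isfile(meta_path):
            with open(meta_path) as f:
                meta = json.load(f)
        results.append((img_path, meta))
    return results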
Example No. 3
def read_validate_model_class_to_idx_map(model_config_fpath, in_project_classes_set, class_to_idx_config_key):
    """Reads class id --> int index mapping from the model config; checks that the set of classes matches the input."""

    if not fs.file_exists(model_config_fpath):
        raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

    with open(model_config_fpath) as fin:
        model_config = json.load(fin)

    model_class_mapping = model_config.get(class_to_idx_config_key, None)
    if model_class_mapping is None:
        raise RuntimeError('Unable to continue_training, model does not have class mapping information.')
    model_classes_set = set(model_class_mapping.keys())

    if model_classes_set != in_project_classes_set:
        error_message_text = 'Unable to continue_training, sets of classes for model and dataset do not match.'
        logger.critical(
            error_message_text, extra={'model_classes': model_classes_set, 'dataset_classes': in_project_classes_set})
        raise RuntimeError(error_message_text)
    return model_class_mapping.copy()
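An illustrative call (hypothetical path and config key), showing that a missing config, a missing mapping, or a class-set mismatch all surface as RuntimeError:

in_project_classes_set = {'car', 'person', 'road'}
try:
    class_to_idx = read_validate_model_class_to_idx_map(
        model_config_fpath='/data/model/config.json',      # hypothetical path
        in_project_classes_set=in_project_classes_set,
        class_to_idx_config_key='class_title_to_idx')      # hypothetical key name
except RuntimeError as err:
    # the function has already logged details for the mismatch case
    print(err)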
Example No. 4
def load_font(font_file_name: str,
              font_size: int = 12) -> ImageFont.FreeTypeFont:
    """
    Set global font true-type for drawing.
    Args:
        font_file_name: name of font file (example: 'DejaVuSansMono.ttf')
        font_size: selected font size
    Returns:
        loaded from file font
    """
    if get_file_ext(font_file_name) == FONT_EXTENSION:
        font_path = _get_font_path_by_name(font_file_name)
        if (font_path is not None) and file_exists(font_path):
            return ImageFont.truetype(font_path, font_size, encoding='utf-8')
        else:
            raise ValueError(
                'Font file "{}" not found in system paths. Try setting another font.'
                .format(font_file_name))
    else:
        raise ValueError('Only TrueType fonts are supported!')
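Illustrative usage (assumes 'DejaVuSansMono.ttf' is resolvable by _get_font_path_by_name and that FONT_EXTENSION is '.ttf'):

font = load_font('DejaVuSansMono.ttf', font_size=24)  # returns an ImageFont.FreeTypeFont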
def import_cityscapes(api: sly.Api, task_id, context, state, app_logger):
    tag_metas = sly.TagMetaCollection()
    obj_classes = sly.ObjClassCollection()
    dataset_names = []

    storage_dir = my_app.data_dir
    if INPUT_DIR:
        cur_files_path = INPUT_DIR
        extract_dir = os.path.join(
            storage_dir,
            str(Path(cur_files_path).parent).lstrip("/"))
        input_dir = os.path.join(extract_dir, Path(cur_files_path).name)
        archive_path = os.path.join(
            storage_dir,
            cur_files_path + ".tar")  # cur_files_path.split("/")[-2] + ".tar"
        project_name = Path(cur_files_path).name
    else:
        cur_files_path = INPUT_FILE
        extract_dir = os.path.join(storage_dir, get_file_name(cur_files_path))
        archive_path = os.path.join(storage_dir,
                                    get_file_name_with_ext(cur_files_path))
        project_name = get_file_name(INPUT_FILE)
        input_dir = os.path.join(storage_dir,
                                 get_file_name(cur_files_path))  # extract_dir
    api.file.download(TEAM_ID, cur_files_path, archive_path)
    if tarfile.is_tarfile(archive_path):
        with tarfile.open(archive_path) as archive:
            archive.extractall(extract_dir)
    else:
        raise Exception("No such file".format(INPUT_FILE))
    new_project = api.project.create(WORKSPACE_ID,
                                     project_name,
                                     change_name_if_conflict=True)
    tags_template = os.path.join(input_dir, "gtFine", "*")
    tags_paths = glob.glob(tags_template)
    tags = [os.path.basename(tag_path) for tag_path in tags_paths]
    if train_tag in tags and val_tag not in tags:
        split_train = True
    elif trainval_tag in tags and val_tag not in tags:
        split_train = True
    else:
        split_train = False
    search_fine = os.path.join(input_dir, "gtFine", "*", "*",
                               "*_gt*_polygons.json")
    files_fine = glob.glob(search_fine)
    files_fine.sort()
    search_imgs = os.path.join(input_dir, "leftImg8bit", "*", "*",
                               "*_leftImg8bit" + IMAGE_EXT)
    files_imgs = glob.glob(search_imgs)
    files_imgs.sort()
    if len(files_fine) == 0 or len(files_imgs) == 0:
        raise Exception('Input data does not match the Cityscapes format')
    samples_count = len(files_fine)
    progress = sly.Progress('Importing images', samples_count)
    images_pathes_for_compare = []
    images_pathes = {}
    images_names = {}
    anns_data = {}
    ds_name_to_id = {}

    if samples_count > 2:
        random_train_indexes = get_split_idxs(samples_count, samplePercent)

    for idx, orig_ann_path in enumerate(files_fine):
        parent_dir, json_filename = os.path.split(
            os.path.abspath(orig_ann_path))
        dataset_name = os.path.basename(parent_dir)
        if dataset_name not in dataset_names:
            dataset_names.append(dataset_name)
            ds = api.dataset.create(new_project.id,
                                    dataset_name,
                                    change_name_if_conflict=True)
            ds_name_to_id[dataset_name] = ds.id
            images_pathes[dataset_name] = []
            images_names[dataset_name] = []
            anns_data[dataset_name] = []
        orig_img_path = json_path_to_image_path(orig_ann_path)
        images_pathes_for_compare.append(orig_img_path)
        if not file_exists(orig_img_path):
            logger.warn(
                'Image for annotation {} not found in dataset {}'.format(
                    orig_ann_path.split('/')[-1], dataset_name))
            continue
        images_pathes[dataset_name].append(orig_img_path)
        images_names[dataset_name].append(
            sly.io.fs.get_file_name_with_ext(orig_img_path))
        tag_path = os.path.split(parent_dir)[0]
        train_val_tag = os.path.basename(tag_path)
        if split_train is True and samples_count > 2:
            if (train_val_tag == train_tag) or (train_val_tag == trainval_tag):
                if idx in random_train_indexes:
                    train_val_tag = train_tag
                else:
                    train_val_tag = val_tag

        # tag_meta = sly.TagMeta(train_val_tag, sly.TagValueType.NONE)
        tag_meta = sly.TagMeta('split', sly.TagValueType.ANY_STRING)
        if not tag_metas.has_key(tag_meta.name):
            tag_metas = tag_metas.add(tag_meta)
        # tag = sly.Tag(tag_meta)
        tag = sly.Tag(meta=tag_meta, value=train_val_tag)
        json_data = json.load(open(orig_ann_path))
        ann = sly.Annotation.from_img_path(orig_img_path)
        for obj in json_data['objects']:
            class_name = obj['label']
            if class_name == 'out of roi':
                polygon = obj['polygon'][:5]
                interiors = [obj['polygon'][5:]]
            else:
                polygon = obj['polygon']
                if len(polygon) < 3:
                    logger.warn(
                        'Polygon must contain at least 3 points in ann {}, obj_class {}'
                        .format(orig_ann_path, class_name))
                    continue
                interiors = []
            interiors = [convert_points(interior) for interior in interiors]
            polygon = sly.Polygon(convert_points(polygon), interiors)
            if city_classes_to_colors.get(class_name, None):
                obj_class = sly.ObjClass(
                    name=class_name,
                    geometry_type=sly.Polygon,
                    color=city_classes_to_colors[class_name])
            else:
                new_color = generate_rgb(city_colors)
                city_colors.append(new_color)
                obj_class = sly.ObjClass(name=class_name,
                                         geometry_type=sly.Polygon,
                                         color=new_color)
            ann = ann.add_label(sly.Label(polygon, obj_class))
            if not obj_classes.has_key(class_name):
                obj_classes = obj_classes.add(obj_class)
        ann = ann.add_tag(tag)
        anns_data[dataset_name].append(ann)
        progress.iter_done_report()
    out_meta = sly.ProjectMeta(obj_classes=obj_classes, tag_metas=tag_metas)
    api.project.update_meta(new_project.id, out_meta.to_json())

    for ds_name, ds_id in ds_name_to_id.items():
        dst_image_infos = api.image.upload_paths(ds_id, images_names[ds_name],
                                                 images_pathes[ds_name])
        dst_image_ids = [img_info.id for img_info in dst_image_infos]
        api.annotation.upload_anns(dst_image_ids, anns_data[ds_name])

    stat_dct = {
        'samples': samples_count,
        'src_ann_cnt': len(files_fine),
        'src_img_cnt': len(files_imgs)
    }
    logger.info('Found img/ann pairs.', extra=stat_dct)
    images_without_anns = set(files_imgs) - set(images_pathes_for_compare)
    if len(images_without_anns) > 0:
        logger.warn('Found source images without corresponding annotations:')
        for im_path in images_without_anns:
            logger.warn('Annotation not found for image {}'.format(im_path))

    logger.info('Found classes.',
                extra={
                    'cnt': len(obj_classes),
                    'classes': sorted([obj_class.name for obj_class in obj_classes])
                })
    logger.info('Created tags.',
                extra={
                    'cnt': len(out_meta.tag_metas),
                    'tags': sorted([tag_meta.name for tag_meta in out_meta.tag_metas])
                })
    my_app.stop()
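For orientation, a stdlib-only sketch of how a single gtFine *_polygons.json file is read into (label, points) pairs, as done in the loop above; the 'out of roi' hole handling and class/color bookkeeping are left to the full importer (hypothetical helper, not part of the app):

import json

def read_cityscapes_polygons(ann_path):
    # Each Cityscapes *_polygons.json holds {"objects": [{"label": ..., "polygon": [[x, y], ...]}, ...]}.
    with open(ann_path) as f:
        data = json.load(f)
    labels = []
    for obj in data.get('objects', []):
        points = obj.get('polygon', [])
        if len(points) < 3:
            continue  # skip degenerate polygons, mirroring the importer's warning
        labels.append((obj['label'], points))
    return labels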