Example No. 1
class DatasetApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.DESCRIPTION, ApiField.SIZE,
        ApiField.PROJECT_ID, ApiField.IMAGES_COUNT, ApiField.CREATED_AT,
        ApiField.UPDATED_AT
    ]
    Info = namedtuple('DatasetInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, project_id, filters=None):
        return self.get_list_all_pages('datasets.list', {
            ApiField.PROJECT_ID: project_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'datasets.info')

    def create(self, project_id, name, description=""):
        response = self.api.post(
            'datasets.add', {
                ApiField.PROJECT_ID: project_id,
                ApiField.NAME: name,
                ApiField.DESCRIPTION: description
            })
        return self._convert_json_info(response.json())

    def _get_update_method(self):
        return 'datasets.editInfo'
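
A minimal usage sketch for DatasetApi. It assumes a configured supervisely_lib client whose Api object exposes this module as api.dataset (as in the Supervisely SDK); the server address, token, and ids are placeholders:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

project_id = 123  # hypothetical project id
ds = api.dataset.create(project_id, 'ds_new', description='created via API')
for info in api.dataset.get_list(project_id):
    print(info.id, info.name, info.images_count)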
Example No. 2
class AgentApi(ModuleApi):
    class Status(Enum):
        WAITING = 'waiting'
        RUNNING = 'running'

    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.TOKEN, ApiField.STATUS,
        ApiField.USER_ID, ApiField.TEAM_ID, ApiField.CAPABILITIES,
        ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]
    Info = namedtuple('AgentInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, team_id, filters=None):
        return self.get_list_all_pages('agents.list', {
            ApiField.TEAM_ID: team_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'agents.info')

    def get_status(self, id):
        status_str = self.get_info_by_id(id).status
        return self.Status(status_str)

    def raise_for_status(self, status):
        # No-op: the Status enum defines only WAITING and RUNNING,
        # so there is no failure state to raise on.
        pass
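
A hedged sketch for AgentApi: get_status converts the raw status string into the Status enum, so callers compare against enum members rather than strings. Assumes api.agent is this module and the agent id is a placeholder:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

agent_id = 42  # hypothetical agent id
if api.agent.get_status(agent_id) is api.agent.Status.RUNNING:
    print('agent is running')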
Example No. 3
class WorkspaceApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.DESCRIPTION, ApiField.TEAM_ID,
        ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]
    Info = namedtuple('WorkspaceInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, team_id, filters=None):
        return self.get_list_all_pages('workspaces.list', {
            ApiField.TEAM_ID: team_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'workspaces.info')

    def create(self, team_id, name, description=""):
        response = self.api.post(
            'workspaces.add', {
                ApiField.TEAM_ID: team_id,
                ApiField.NAME: name,
                ApiField.DESCRIPTION: description
            })
        return self._convert_json_info(response.json())

    def _get_update_method(self):
        return 'workspaces.editInfo'
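
The filters argument is forwarded to the server as-is; the dict format below follows the field/operator/value convention visible in PluginApi.get_info_by_id further down. A sketch with placeholder ids:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

team_id = 7  # hypothetical team id
filters = [{'field': 'name', 'operator': '=', 'value': 'research'}]
for ws in api.workspace.get_list(team_id, filters=filters):
    print(ws.id, ws.name)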
Example No. 4
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    try:
        cls.InfoType = namedtuple(
            cls.info_tuple_name(),
            [camel_to_snake(name) for name in cls.info_sequence()])
    except NotImplementedError:
        pass
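
This hook builds an InfoType namedtuple for every subclass by snake-casing the camelCase API field names, which is what makes attribute access like info.created_at work. A standalone sketch of the same idea; camel_to_snake here is a local stand-in for the SDK helper:

import re
from collections import namedtuple

def camel_to_snake(name):
    # 'createdAt' -> 'created_at'
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()

info_sequence = ['id', 'name', 'createdAt', 'updatedAt']
DatasetInfo = namedtuple('DatasetInfo',
                         [camel_to_snake(n) for n in info_sequence])
print(DatasetInfo._fields)  # ('id', 'name', 'created_at', 'updated_at')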
Example No. 5
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    try:
        field_names = []
        for name in cls.info_sequence():
            if type(name) is str:
                field_names.append(camel_to_snake(name))
            elif type(name) is tuple and type(name[1]) is str:
                field_names.append(name[1])
            else:
                raise RuntimeError('Can not parse field {!r}'.format(name))
        cls.InfoType = namedtuple(cls.info_tuple_name(), field_names)
    except NotImplementedError:
        pass
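
The variant above additionally accepts (camelCaseName, explicit_name) pairs so a subclass can override the generated field name. A sketch of the parsing rule, again with a stand-in camel_to_snake:

import re
from collections import namedtuple

def camel_to_snake(name):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()

info_sequence = ['id', ('imagesCount', 'num_images'), 'createdAt']
field_names = [camel_to_snake(n) if isinstance(n, str) else n[1]
               for n in info_sequence]
print(namedtuple('ProjectInfo', field_names)._fields)
# ('id', 'num_images', 'created_at')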
Example No. 6
class ProjectApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.DESCRIPTION, ApiField.SIZE,
        ApiField.README, ApiField.WORKSPACE_ID, ApiField.CREATED_AT,
        ApiField.UPDATED_AT
    ]
    Info = namedtuple('ProjectInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, workspace_id, filters=None):
        return self.get_list_all_pages('projects.list', {
            ApiField.WORKSPACE_ID: workspace_id,
            "filter": filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'projects.info')

    def get_meta(self, id):
        response = self.api.post('projects.meta', {'id': id})
        return response.json()

    def create(self, workspace_id, name, description=""):
        response = self.api.post(
            'projects.add', {
                ApiField.WORKSPACE_ID: workspace_id,
                ApiField.NAME: name,
                ApiField.DESCRIPTION: description
            })
        return self._convert_json_info(response.json())

    def _get_update_method(self):
        return 'projects.editInfo'

    def update_meta(self, id, meta):
        self.api.post('projects.meta.update', {
            ApiField.ID: id,
            ApiField.META: meta
        })

    def _clone_api_method_name(self):
        return 'projects.clone'

    def get_datasets_count(self, id):
        datasets = self.api.dataset.get_list(id)
        return len(datasets)

    def get_images_count(self, id):
        datasets = self.api.dataset.get_list(id)
        return sum([dataset.images_count for dataset in datasets])
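
A usage sketch for ProjectApi; note that get_images_count sums the per-dataset counters via DatasetApi rather than issuing a dedicated request. Placeholder ids throughout:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

workspace_id = 7  # hypothetical workspace id
project = api.project.create(workspace_id, 'demo_project')
meta_json = api.project.get_meta(project.id)  # raw project meta as JSON
print(api.project.get_datasets_count(project.id),
      api.project.get_images_count(project.id))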
Example No. 7
class PluginApi(ModuleApi):
    _info_sequence = [ApiField.ID,
                      ApiField.NAME,
                      ApiField.DESCRIPTION,
                      ApiField.TYPE,
                      ApiField.DEFAULT_VERSION,
                      ApiField.DOCKER_IMAGE,
                      ApiField.README,
                      ApiField.CONFIGS,
                      ApiField.VERSIONS,
                      ApiField.CREATED_AT,
                      ApiField.UPDATED_AT]
    Info = namedtuple('PluginInfo', [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, team_id, filters=None):
        return self.get_list_all_pages('plugins.list',  {ApiField.TEAM_ID: team_id, ApiField.FILTER: filters or []})

    def get_info_by_id(self, team_id, plugin_id):
        filters = [{"field": ApiField.ID, "operator": "=", "value": plugin_id}]
        return self._get_info_by_filters(team_id, filters)
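
Unlike the other modules, this get_info_by_id needs a team_id because plugins are only listed per team; the lookup is a filtered plugins.list call. A sketch with placeholder ids:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

team_id, plugin_id = 7, 101  # hypothetical ids
plugin = api.plugin.get_info_by_id(team_id, plugin_id)
print(plugin.name, plugin.default_version, plugin.docker_image)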
Example No. 8
class AnnotationApi(ModuleApi):
    _info_sequence = [ApiField.IMAGE_ID,
                      ApiField.IMAGE_NAME,
                      ApiField.ANNOTATION,
                      ApiField.CREATED_AT,
                      ApiField.UPDATED_AT]
    Info = namedtuple('AnnotationInfo', [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, dataset_id, filters=None, progress_cb=None):
        return self.get_list_all_pages('annotations.list',  {ApiField.DATASET_ID: dataset_id, ApiField.FILTER: filters or []}, progress_cb)

    def download(self, image_id):
        response = self.api.post('annotations.info', {ApiField.IMAGE_ID: image_id})
        return self._convert_json_info(response.json())

    def download_batch(self, dataset_id, image_ids, progress_cb=None):
        id_to_ann = {}
        for batch in batched(image_ids):
            results = self.api.post('annotations.bulk.info', data={ApiField.DATASET_ID: dataset_id, ApiField.IMAGE_IDS: batch}).json()
            for ann_dict in results:
                ann_info = self._convert_json_info(ann_dict)
                id_to_ann[ann_info.image_id] = ann_info
            if progress_cb is not None:
                progress_cb(len(batch))
        ordered_results = [id_to_ann[image_id] for image_id in image_ids]
        return ordered_results

    def upload_path(self, img_id, ann_path):
        self.upload_paths([img_id], [ann_path])

    def upload_paths(self, img_ids, ann_paths, progress_cb=None):
        # img_ids from the same dataset
        def read_json(ann_path):
            with open(ann_path) as json_file:
                return json.load(json_file)
        self._upload_batch(read_json, img_ids, ann_paths, progress_cb)

    def upload_json(self, img_id, ann_json):
        self.upload_jsons([img_id], [ann_json])

    def upload_jsons(self, img_ids, ann_jsons, progress_cb=None):
        # img_ids from the same dataset
        self._upload_batch(lambda x: x, img_ids, ann_jsons, progress_cb)

    def upload_ann(self, img_id, ann):
        self.upload_anns([img_id], [ann])

    def upload_anns(self, img_ids, anns, progress_cb=None):
        # img_ids from the same dataset
        self._upload_batch(Annotation.to_json, img_ids, anns, progress_cb)

    def _upload_batch(self, func_ann_to_json, img_ids, anns, progress_cb=None):
        # img_ids from the same dataset
        if len(img_ids) == 0:
            return
        if len(img_ids) != len(anns):
            raise RuntimeError("Can not match \"img_ids\" and \"anns\" lists, len(img_ids) != len(anns)")

        dataset_id = self.api.image.get_info_by_id(img_ids[0]).dataset_id
        for batch in batched(list(zip(img_ids, anns))):
            data = [{ApiField.IMAGE_ID: img_id, ApiField.ANNOTATION: func_ann_to_json(ann)} for img_id, ann in batch]
            self.api.post('annotations.bulk.add', data={ApiField.DATASET_ID: dataset_id, ApiField.ANNOTATIONS: data})
            if progress_cb is not None:
                progress_cb(len(batch))

    def get_info_by_id(self, id):
        raise RuntimeError('Method is not supported')

    def get_info_by_name(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def exists(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def get_free_name(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def _add_sort_param(self, data):
        return data
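
A roundtrip sketch for this AnnotationApi: download_batch returns AnnotationInfo tuples in the same order as image_ids, and upload_json pushes raw annotation JSON back. Ids are placeholders and must belong to one dataset:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

dataset_id = 5            # hypothetical dataset id
image_ids = [11, 12, 13]  # hypothetical image ids from that dataset
ann_infos = api.annotation.download_batch(dataset_id, image_ids)
# Copy the first image's annotation JSON onto the second image.
api.annotation.upload_json(image_ids[1], ann_infos[0].annotation)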
Example No. 9
class ImageApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.LINK, ApiField.HASH,
        ApiField.MIME, ApiField.EXT, ApiField.SIZE, ApiField.WIDTH,
        ApiField.HEIGHT, ApiField.LABELS_COUNT, ApiField.DATASET_ID,
        ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]

    Info = namedtuple('ImageInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, dataset_id, filters=None):
        return self.get_list_all_pages('images.list', {
            ApiField.DATASET_ID: dataset_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'images.info')

    # @TODO: reimplement to new method images.bulk.info
    def get_info_by_id_batch(self, ids):
        results = []
        if len(ids) == 0:
            return results
        dataset_id = self.get_info_by_id(ids[0]).dataset_id
        for batch in batched(ids):
            filters = [{
                "field": ApiField.ID,
                "operator": "in",
                "value": batch
            }]
            results.extend(
                self.get_list_all_pages('images.list', {
                    ApiField.DATASET_ID: dataset_id,
                    ApiField.FILTER: filters
                }))
        return results

    def _download(self, id):
        response = self.api.post('images.download', {ApiField.ID: id})
        return response

    def download_np(self, id):
        response = self._download(id)
        img = sly_image.read_bytes(response.content)
        return img

    def download_path(self, id, path):
        response = self._download(id)
        ensure_base_path(path)
        with open(path, 'wb') as fd:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                fd.write(chunk)

    def _download_batch(self, dataset_id, ids):
        for batch_ids in batched(ids):
            response = self.api.post('images.bulk.download', {
                ApiField.DATASET_ID: dataset_id,
                ApiField.IMAGE_IDS: batch_ids
            })
            decoder = MultipartDecoder.from_response(response)
            for idx, part in enumerate(decoder.parts):
                img_id = int(
                    re.findall(
                        'name="(.*)"',
                        part.headers[b'Content-Disposition'].decode('utf-8'))
                    [0])
                yield img_id, part

    def download_paths(self, dataset_id, ids, paths, progress_cb=None):
        if len(ids) == 0:
            return
        if len(ids) != len(paths):
            raise RuntimeError(
                "Can not match \"ids\" and \"paths\" lists, len(ids) != len(paths)"
            )

        id_to_path = {id: path for id, path in zip(ids, paths)}
        #debug_ids = []
        for img_id, resp_part in self._download_batch(dataset_id, ids):
            #debug_ids.append(img_id)
            with open(id_to_path[img_id], 'wb') as w:
                w.write(resp_part.content)
            if progress_cb is not None:
                progress_cb(1)
        #if ids != debug_ids:
        #    raise RuntimeError("images.bulk.download: imageIds order is broken")

    def download_nps(self, dataset_id, ids, progress_cb=None):
        images = []
        if len(ids) == 0:
            return images

        id_to_img = {}
        for img_id, resp_part in self._download_batch(dataset_id, ids):
            id_to_img[img_id] = sly_image.read_bytes(resp_part.content)
            if progress_cb is not None:
                progress_cb(1)

        images = [id_to_img[id] for id in ids]
        return images

    def check_existing_hashes(self, hashes):
        results = []
        if len(hashes) == 0:
            return results
        for hashes_batch in batched(hashes, batch_size=900):
            response = self.api.post('images.internal.hashes.list',
                                     hashes_batch)
            results.extend(response.json())
        return results

    def check_image_uploaded(self, hash):
        response = self.api.post('images.internal.hashes.list', [hash])
        results = response.json()
        return len(results) > 0

    def _upload_data_bulk(self, func_item_to_byte_stream, func_item_hash,
                          items, progress_cb):
        hashes = []
        if len(items) == 0:
            return hashes

        hash_to_items = defaultdict(list)

        for idx, item in enumerate(items):
            item_hash = func_item_hash(item)
            hashes.append(item_hash)
            hash_to_items[item_hash].append(item)

        unique_hashes = set(hashes)
        remote_hashes = self.check_existing_hashes(list(unique_hashes))
        new_hashes = unique_hashes - set(remote_hashes)

        if progress_cb is not None:
            progress_cb(len(remote_hashes))

        # upload only new images to supervisely server
        items_to_upload = []
        for hash in new_hashes:
            items_to_upload.extend(hash_to_items[hash])

        for batch in batched(items_to_upload):
            content_dict = {}
            for idx, item in enumerate(batch):
                content_dict["{}-file".format(idx)] = (
                    str(idx), func_item_to_byte_stream(item), 'image/*')
            encoder = MultipartEncoder(fields=content_dict)
            self.api.post('images.bulk.upload', encoder)
            if progress_cb is not None:
                progress_cb(len(batch))

        return hashes

    def upload_path(self, dataset_id, name, path):
        return self.upload_paths(dataset_id, [name], [path])[0]

    def upload_paths(self, dataset_id, names, paths, progress_cb=None):
        def path_to_bytes_stream(path):
            return open(path, 'rb')

        hashes = self._upload_data_bulk(path_to_bytes_stream, get_file_hash,
                                        paths, progress_cb)
        return self.upload_hashes(dataset_id, names, hashes)

    def upload_np(self, dataset_id, name, img):
        return self.upload_nps(dataset_id, [name], [img])[0]

    def upload_nps(self, dataset_id, names, imgs, progress_cb=None):
        def img_to_bytes_stream(item):
            img, name = item[0], item[1]
            img_bytes = sly_image.write_bytes(img, get_file_ext(name))
            return io.BytesIO(img_bytes)

        def img_to_hash(item):
            img, name = item[0], item[1]
            return sly_image.get_hash(img, get_file_ext(name))

        img_name_list = list(zip(imgs, names))
        hashes = self._upload_data_bulk(img_to_bytes_stream, img_to_hash,
                                        img_name_list, progress_cb)
        return self.upload_hashes(dataset_id, names, hashes)

    def upload_link(self, dataset_id, name, link):
        return self.upload_links(dataset_id, [name], [link])[0]

    def upload_links(self, dataset_id, names, links, progress_cb=None):
        return self._upload_bulk_add(lambda item: (ApiField.LINK, item),
                                     dataset_id, names, links, progress_cb)

    def upload_hash(self, dataset_id, name, hash):
        return self.upload_hashes(dataset_id, [name], [hash])[0]

    def upload_hashes(self, dataset_id, names, hashes, progress_cb=None):
        return self._upload_bulk_add(lambda item: (ApiField.HASH, item),
                                     dataset_id, names, hashes, progress_cb)

    def upload_id(self, dataset_id, name, id):
        return self.upload_ids(dataset_id, [name], [id])[0]

    def upload_ids(self, dataset_id, names, ids, progress_cb=None):
        # all ids have to be from single dataset
        infos = self.get_info_by_id_batch(ids)
        hashes = [info.hash for info in infos]
        return self.upload_hashes(dataset_id, names, hashes, progress_cb)

    def _upload_bulk_add(self,
                         func_item_to_kv,
                         dataset_id,
                         names,
                         items,
                         progress_cb=None):
        results = []

        if len(names) == 0:
            return results
        if len(names) != len(items):
            raise RuntimeError(
                "Can not match \"names\" and \"items\" lists, len(names) != len(items)"
            )

        for batch in batched(list(zip(names, items))):
            images = []
            for name, item in batch:
                item_tuple = func_item_to_kv(item)
                #@TODO: 'title' -> ApiField.NAME
                images.append({'title': name, item_tuple[0]: item_tuple[1]})
            response = self.api.post('images.bulk.add', {
                ApiField.DATASET_ID: dataset_id,
                ApiField.IMAGES: images
            })
            if progress_cb is not None:
                progress_cb(len(images))
            results.extend([
                self._convert_json_info(info_json)
                for info_json in response.json()
            ])

        name_to_res = {img_info.name: img_info for img_info in results}
        ordered_results = [name_to_res[name] for name in names]

        return ordered_results

    #@TODO: reimplement
    def _convert_json_info(self, info: dict):
        if info is None:
            return None
        temp_ext = None
        field_values = []
        for field_name in self.__class__._info_sequence:
            if field_name == ApiField.EXT:
                continue
            field_values.append(info[field_name])
            if field_name == ApiField.MIME:
                temp_ext = info[field_name].split('/')[1]
                field_values.append(temp_ext)
        for idx, field_name in enumerate(self.__class__._info_sequence):
            if field_name == ApiField.NAME:
                cur_ext = get_file_ext(field_values[idx])
                if not cur_ext:
                    field_values[idx] = "{}.{}".format(field_values[idx],
                                                       temp_ext)
                    break

                cur_ext = cur_ext.replace(".", "").lower()
                if temp_ext == 'jpeg' and cur_ext in ['jpg', 'jpeg']:
                    break

                if temp_ext not in field_values[idx]:
                    field_values[idx] = "{}.{}".format(field_values[idx],
                                                       temp_ext)
                break
        return self.__class__.Info._make(field_values)
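
A sketch of the bulk paths workflow: upload_paths hashes the files first and physically uploads only hashes unknown to the server, then registers the names; download_paths streams the files back. All ids and paths are placeholders:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

dataset_id = 5  # hypothetical dataset id
infos = api.image.upload_paths(dataset_id, ['a.jpg', 'b.jpg'],
                               ['/data/a.jpg', '/data/b.jpg'])
api.image.download_paths(dataset_id, [info.id for info in infos],
                         ['/tmp/a.jpg', '/tmp/b.jpg'])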
Example No. 10
class AnnotationApi(ModuleApi):
    _info_sequence = [
        ApiField.IMAGE_ID, ApiField.IMAGE_NAME, ApiField.ANNOTATION,
        ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]
    Info = namedtuple('AnnotationInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, dataset_id, filters=None, progress_cb=None):
        return self.get_list_all_pages('annotations.list', {
            ApiField.DATASET_ID: dataset_id,
            ApiField.FILTER: filters or []
        }, progress_cb)

    def download(self, image_id):
        response = self.api.post('annotations.info',
                                 {ApiField.IMAGE_ID: image_id})
        return self._convert_json_info(response.json())

    def download_batch(self, dataset_id, image_ids, progress_cb=None):
        filters = [{
            "field": ApiField.IMAGE_ID,
            "operator": "in",
            "value": image_ids
        }]
        return self.get_list_all_pages('annotations.list', {
            ApiField.DATASET_ID: dataset_id,
            ApiField.FILTER: filters
        }, progress_cb)

    # @TODO: no errors from api if annotation is not valid
    def upload(self, image_id: int, ann: dict):
        self.api.post('annotations.add',
                      data={
                          ApiField.IMAGE_ID: image_id,
                          ApiField.ANNOTATION: ann
                      })

    def upload_batch_paths(self,
                           dataset_id,
                           img_ids,
                           ann_paths,
                           progress_cb=None):
        MAX_BATCH_SIZE = 50
        for batch in batched(list(zip(img_ids, ann_paths)), MAX_BATCH_SIZE):
            data = []
            for img_id, ann_path in batch:
                with open(ann_path) as json_file:
                    ann_json = json.load(json_file)
                data.append({
                    ApiField.IMAGE_ID: img_id,
                    ApiField.ANNOTATION: ann_json
                })
            self.api.post('annotations.bulk.add',
                          data={
                              ApiField.DATASET_ID: dataset_id,
                              ApiField.ANNOTATIONS: data
                          })
            if progress_cb is not None:
                progress_cb(len(batch))

    def upload_batch(self, img_ids, anns, progress_cb=None):
        raise NotImplementedError()

    def get_info_by_id(self, id):
        raise RuntimeError('Method is not supported')

    def get_info_by_name(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def exists(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def get_free_name(self, parent_id, name):
        raise RuntimeError('Method is not supported')

    def _add_sort_param(self, data):
        return data
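
This older AnnotationApi reads annotation files straight from disk in batches of 50. A sketch, assuming the ids and JSON paths line up pairwise:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

dataset_id = 5      # hypothetical dataset id
img_ids = [11, 12]  # hypothetical image ids in that dataset
ann_paths = ['/data/11.json', '/data/12.json']
api.annotation.upload_batch_paths(dataset_id, img_ids, ann_paths,
                                  progress_cb=lambda n: print(n, 'done'))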
Example No. 11
class NeuralNetworkApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.DESCRIPTION, ApiField.CONFIG,
        ApiField.HASH, ApiField.ONLY_TRAIN, ApiField.PLUGIN_ID,
        ApiField.PLUGIN_VERSION, ApiField.SIZE, ApiField.WEIGHTS_LOCATION,
        ApiField.README, ApiField.TASK_ID, ApiField.USER_ID, ApiField.TEAM_ID,
        ApiField.WORKSPACE_ID, ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]
    Info = namedtuple('ModelInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, workspace_id, filters=None):
        return self.get_list_all_pages('models.list', {
            ApiField.WORKSPACE_ID: workspace_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'models.info')

    def download(self, id):
        response = self.api.post('models.download', {ApiField.ID: id},
                                 stream=True)
        return response

    def download_to_tar(self, workspace_id, name, tar_path, progress_cb=None):
        model = self.get_info_by_name(workspace_id, name)
        response = self.download(model.id)
        ensure_base_path(tar_path)
        with open(tar_path, 'wb') as fd:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                fd.write(chunk)
                if progress_cb is not None:
                    read_mb = len(chunk) / 1024.0 / 1024.0
                    progress_cb(read_mb)

    def download_to_dir(self, workspace_id, name, directory, progress_cb=None):
        model_tar = os.path.join(directory, rand_str(10) + '.tar')
        self.download_to_tar(workspace_id, name, model_tar, progress_cb)
        model_dir = os.path.join(directory, name)
        with tarfile.open(model_tar) as archive:
            archive.extractall(model_dir)
        silent_remove(model_tar)
        return model_dir

    def generate_hash(self, task_id):
        response = self.api.post('models.hash.create',
                                 {ApiField.TASK_ID: task_id})
        return response.json()

    def upload(self, hash, archive_path, progress_cb=None):
        encoder = MultipartEncoder({
            'hash': hash,
            'weights': (os.path.basename(archive_path),
                        open(archive_path, 'rb'),
                        'application/x-tar')
        })

        def callback(monitor):
            read_mb = monitor.bytes_read / 1024.0 / 1024.0
            if progress_cb is not None:
                progress_cb(read_mb)

        monitor = MultipartEncoderMonitor(encoder, callback)
        self.api.post('models.upload', monitor)

    def inference_remote_image(self,
                               id,
                               image_hash,
                               ann=None,
                               meta=None,
                               mode=None,
                               ext=None):
        data = {
            "request_type": "inference",
            "meta": meta or ProjectMeta().to_json(),
            "annotation": ann or None,
            "mode": mode or {},
            "image_hash": image_hash
        }
        # JPEG encoding expects 8-bit data, so build the dummy image as uint8.
        fake_img_data = sly_image.write_bytes(
            np.zeros([5, 5, 3], dtype=np.uint8), '.jpg')
        encoder = MultipartEncoder({
            'id': str(id).encode('utf-8'),
            'data': json.dumps(data),
            'image': ("img", fake_img_data, "")
        })
        response = self.api.post('models.infer',
                                 MultipartEncoderMonitor(encoder))
        return response.json()

    def inference(self, id, img, ann=None, meta=None, mode=None, ext=None):
        data = {
            "request_type": "inference",
            "meta": meta or ProjectMeta().to_json(),
            "annotation": ann or None,
            "mode": mode or {},
        }
        img_data = sly_image.write_bytes(img, ext or '.jpg')
        encoder = MultipartEncoder({
            'id': str(id).encode('utf-8'),
            'data': json.dumps(data),
            'image': ("img", img_data, "")
        })

        response = self.api.post('models.infer',
                                 MultipartEncoderMonitor(encoder))
        return response.json()

    def get_output_meta(self, id, input_meta=None, inference_mode=None):
        data = {
            "request_type": "get_out_meta",
            "meta": input_meta or ProjectMeta().to_json(),
            "mode": inference_mode or {}
        }
        encoder = MultipartEncoder({
            'id': str(id).encode('utf-8'),
            'data': json.dumps(data)
        })
        response = self.api.post('models.infer',
                                 MultipartEncoderMonitor(encoder))
        response_json = response.json()
        if 'out_meta' in response_json:
            return response_json['out_meta']
        return response_json

    def get_deploy_tasks(self, model_id):
        response = self.api.post('models.info.deployed', {'id': model_id})
        return [task[ApiField.ID] for task in response.json()]

    def _clone_api_method_name(self):
        return 'models.clone'
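
An inference sketch, assuming the client exposes this module as api.model and that the model id refers to a deployed model; the image is fetched back from the server so no local file is needed:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

model_id, image_id = 9, 11  # hypothetical ids
img = api.image.download_np(image_id)
prediction = api.model.inference(model_id, img)  # annotation JSON
out_meta = api.model.get_output_meta(model_id)   # classes the model emits
print(out_meta)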
Example No. 12
class ImageApi(ModuleApi):
    _info_sequence = [
        ApiField.ID, ApiField.NAME, ApiField.LINK, ApiField.HASH,
        ApiField.MIME, ApiField.EXT, ApiField.SIZE, ApiField.WIDTH,
        ApiField.HEIGHT, ApiField.LABELS_COUNT, ApiField.DATASET_ID,
        ApiField.CREATED_AT, ApiField.UPDATED_AT
    ]

    Info = namedtuple('ImageInfo',
                      [camel_to_snake(name) for name in _info_sequence])

    def get_list(self, dataset_id, filters=None):
        return self.get_list_all_pages('images.list', {
            ApiField.DATASET_ID: dataset_id,
            ApiField.FILTER: filters or []
        })

    def get_info_by_id(self, id):
        return self._get_info_by_id(id, 'images.info')

    def download(self, id):
        response = self.api.post('images.download', {ApiField.ID: id})
        return response

    def download_np(self, id):
        response = self.download(id)
        image_bytes = np.asarray(bytearray(response.content), dtype="uint8")
        img = image.read_bytes(image_bytes)
        return img

    def download_to_file(self, id, path):
        response = self.download(id)
        ensure_base_path(path)
        with open(path, 'wb') as fd:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                fd.write(chunk)

    def download_batch(self, dataset_id, ids, paths, progress_cb=None):
        id_to_path = {id: path for id, path in zip(ids, paths)}
        MAX_BATCH_SIZE = 50
        for batch_ids in batched(ids, MAX_BATCH_SIZE):
            response = self.api.post('images.bulk.download', {
                ApiField.DATASET_ID: dataset_id,
                ApiField.IMAGE_IDS: batch_ids
            })
            decoder = MultipartDecoder.from_response(response)
            for idx, part in enumerate(decoder.parts):
                img_id = int(
                    re.findall(
                        'name="(.*)"',
                        part.headers[b'Content-Disposition'].decode('utf-8'))
                    [0])
                with open(id_to_path[img_id], 'wb') as w:
                    w.write(part.content)
                if progress_cb is not None:
                    progress_cb(1)

    def check_existing_hashes(self, hashes):
        BATCH_SIZE = 900
        results = []
        for hashes_batch in batched(hashes, BATCH_SIZE):
            response = self.api.post('images.internal.hashes.list',
                                     hashes_batch)
            results.extend(response.json())
        return results

    def check_image_uploaded(self, hash):
        response = self.api.post('images.internal.hashes.list', [hash])
        results = response.json()
        return len(results) > 0

    def upload_np(self, img_np, ext='.png'):
        data = image.write_bytes(img_np, ext)
        return self.upload(data)

    def upload_path(self, img_path):
        with open(img_path, 'rb') as f:
            data = f.read()
        return self.upload(data)

    def upload(self, data):
        response = self.api.post('images.upload', data)
        return response.json()[ApiField.HASH]

    def upload_link(self, link):
        response = self.api.post('images.remote.upsert', {ApiField.LINK: link})
        return response.json()[ApiField.ID]

    def upload_batch_paths(self, img_paths, progress_cb=None):
        MAX_BATCH_SIZE = 50
        for batch_paths in batched(img_paths, MAX_BATCH_SIZE):
            content_dict = {}
            for idx, path in enumerate(batch_paths):
                content_dict["{}-file".format(idx)] = (str(idx),
                                                       open(path,
                                                            'rb'), 'image/*')
            encoder = MultipartEncoder(fields=content_dict)
            self.api.post('images.bulk.upload', encoder)
            if progress_cb is not None:
                progress_cb(len(batch_paths))

    def add(self, dataset_id, name, hash):
        response = self.api.post(
            'images.add', {
                ApiField.DATASET_ID: dataset_id,
                ApiField.HASH: hash,
                ApiField.NAME: name
            })
        return self._convert_json_info(response.json())

    def add_link(self, dataset_id, name, link):
        response = self.api.post(
            'images.add', {
                ApiField.DATASET_ID: dataset_id,
                ApiField.LINK: link,
                ApiField.NAME: name
            })
        return self._convert_json_info(response.json())

    def add_batch(self, dataset_id, names, hashes, progress_cb=None):
        # @TODO: ApiField.NAME
        images = [{
            'title': name,
            ApiField.HASH: hash
        } for name, hash in zip(names, hashes)]
        response = self.api.post('images.bulk.add', {
            ApiField.DATASET_ID: dataset_id,
            ApiField.IMAGES: images
        })
        if progress_cb is not None:
            progress_cb(len(images))
        return [
            self._convert_json_info(info_json)
            for info_json in response.json()
        ]

    def _convert_json_info(self, info: dict):
        if info is None:
            return None
        field_values = []
        for field_name in self.__class__._info_sequence:
            if field_name == ApiField.EXT:
                continue
            field_values.append(info[field_name])
            if field_name == ApiField.MIME:
                field_values.append(info[field_name].split('/')[1])
        return self.__class__.Info._make(field_values)
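
A sketch of the two-step flow in this older ImageApi: upload pushes the raw bytes and returns a content hash, and add attaches that hash to a dataset under a chosen name. Assumes a client built against this variant; ids and paths are placeholders:

import supervisely_lib as sly

api = sly.Api('https://app.supervise.ly', '<API_TOKEN>')  # placeholder credentials

dataset_id = 5  # hypothetical dataset id
img_hash = api.image.upload_path('/data/a.jpg')  # returns content hash
info = api.image.add(dataset_id, 'a.jpg', img_hash)
print(info.id, info.mime)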