Esempio n. 1
0
def _download_cli(cfg, is_cli_call=True):

    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return

    if not token or not dataset_id:
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # get tag id
    tag_name_id_dict = dict([tag.name, tag.id]
                            for tag in api_workflow_client._get_all_tags())
    tag_id = tag_name_id_dict.get(tag_name, None)
    if tag_id is None:
        print(f'The specified tag {tag_name} does not exist.')
        return

    # get tag data
    tag_data = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    # get samples
    chosen_samples_ids = BitMask.from_hex(tag_data.bit_mask_data).to_indices()
    samples = [
        api_workflow_client.filenames_on_server[i] for i in chosen_samples_ids
    ]

    # store sample names in a .txt file
    with open(cfg['tag_name'] + '.txt', 'w') as f:
        for item in samples:
            f.write("%s\n" % item)

    msg = 'The list of files in tag {} is stored at: '.format(cfg['tag_name'])
    msg += os.path.join(os.getcwd(), cfg['tag_name'] + '.txt')
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # download full images from api
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {input_dir} to {output_dir}.')

        # create a dataset from the input directory
        dataset = data.LightlyDataset(input_dir=input_dir)

        # dump the dataset in the output directory
        dataset.dump(output_dir, samples)
Esempio n. 2
0
    def _get_preselected_tag_bitmask(self):
        """Return the bitmask of the preselected tag.

        Without a preselected tag id, an empty bitmask (``'0x0'``) is
        returned; otherwise the tag is fetched from the API and its bitmask
        is decoded.

        """
        if self._preselected_tag_id is None:
            # no preselected tag: no sample belongs to it
            return BitMask.from_hex('0x0')

        client = self.api_workflow_client
        tag_data = client._tags_api.get_tag_by_tag_id(
            client.dataset_id, tag_id=self._preselected_tag_id)
        return BitMask.from_hex(tag_data.bit_mask_data)
Esempio n. 3
0
    def _set_labeled_and_unlabeled_set(self,
                                       preselected_tag_data: TagData = None):
        """Sets the labeled and unlabeled set based on the preselected and query tag id.

        The bitmasks of both tags are loaded from the server and translated
        into filenames using the server's list of filenames.

        Args:
            preselected_tag_data:
                Optional data of the preselected tag. If given, it is used
                directly instead of being loaded from the API.

        """

        def filenames_from_tag_data(tag_data):
            # Translate a tag's hex bitmask into the filenames it selects.
            # The server filename list is read once per conversion rather
            # than once per selected sample.
            chosen_samples_ids = BitMask.from_hex(
                tag_data.bit_mask_data).to_indices()
            filenames_on_server = self.api_workflow_client.filenames_on_server
            return [filenames_on_server[i] for i in chosen_samples_ids]

        if self.preselected_tag_id is None:
            self.labeled_set = []
        else:
            if preselected_tag_data is None:
                preselected_tag_data = self.api_workflow_client.tags_api.get_tag_by_tag_id(
                    self.api_workflow_client.dataset_id,
                    tag_id=self.preselected_tag_id)
            self.labeled_set = filenames_from_tag_data(preselected_tag_data)

        # The unlabeled set is only initialized on the first call; later calls
        # reuse the previous (already filtered) unlabeled set.
        if not hasattr(self, "unlabeled_set"):
            if self.query_tag_id is None:
                self.unlabeled_set = self.api_workflow_client.filenames_on_server
            else:
                query_tag_data = self.api_workflow_client.tags_api.get_tag_by_tag_id(
                    self.api_workflow_client.dataset_id,
                    tag_id=self.query_tag_id)
                self.unlabeled_set = filenames_from_tag_data(query_tag_data)

        # Remove every labeled filename; a set keeps each membership test O(1).
        filenames_labeled = set(self.labeled_set)
        self.unlabeled_set = [
            f for f in self.unlabeled_set if f not in filenames_labeled
        ]
Esempio n. 4
0
    def _get_query_tag_bitmask(self):
        """Return the bitmask of the query tag.

        The query tag is fetched from the API by ``self._query_tag_id`` and
        its hex bitmask is decoded.

        """
        client = self.api_workflow_client
        tag_data = client._tags_api.get_tag_by_tag_id(
            client.dataset_id, tag_id=self._query_tag_id)
        return BitMask.from_hex(tag_data.bit_mask_data)
Esempio n. 5
0
    def _set_labeled_and_unlabeled_set(self,
                                       preselected_tag_data: TagData = None):
        """Update the labeled, added and unlabeled sets from the server tags.

        The bitmasks of the preselected and query tags are loaded from the
        server and turned into filenames via the server's filename list.
        The added set holds the samples that became labeled since the
        previous call.

        Args:
            preselected_tag_data:
                Optional data of the preselected tag. If given, it is used
                instead of loading the tag from the API.

        """
        client = self.api_workflow_client

        # On the first call, start with an empty labeled and added set.
        if not hasattr(self, "bitmask_labeled_set"):
            self.bitmask_labeled_set = BitMask.from_hex("0x0")
            self.bitmask_added_set = BitMask.from_hex("0x0")

        # Without a preselected tag id, the current labeled/added bitmasks stay.
        if self.preselected_tag_id is not None:
            if preselected_tag_data is None:
                preselected_tag_data = client.tags_api.get_tag_by_tag_id(
                    client.dataset_id,
                    tag_id=self.preselected_tag_id)
            labeled_now = BitMask.from_hex(preselected_tag_data.bit_mask_data)
            # samples labeled now that were not labeled before
            self.bitmask_added_set = labeled_now - self.bitmask_labeled_set
            self.bitmask_labeled_set = labeled_now

        if self.query_tag_id is not None:
            query_tag_data = client.tags_api.get_tag_by_tag_id(
                client.dataset_id, tag_id=self.query_tag_id)
            query_bitmask = BitMask.from_hex(query_tag_data.bit_mask_data)
        else:
            query_bitmask = BitMask.from_length(
                len(client.filenames_on_server))
        self.bitmask_unlabeled_set = query_bitmask - self.bitmask_labeled_set

        self.labeled_set = self.bitmask_labeled_set.masked_select_from_list(
            client.filenames_on_server)
        self.added_set = self.bitmask_added_set.masked_select_from_list(
            client.filenames_on_server)
        self.unlabeled_set = self.bitmask_unlabeled_set.masked_select_from_list(
            client.filenames_on_server)
Esempio n. 6
0
    def get_filenames_in_tag(
        self,
        tag_data: TagData,
        filenames_on_server: List[str] = None,
        exclude_parent_tag: bool = False,
    ) -> List[str]:
        """Return the filenames of all samples contained in a tag.

        Args:
            tag_data:
                The data of the tag.
            filenames_on_server:
                List of all filenames on the server. If it is missing or
                empty, it is fetched via ``self.get_filenames()``, which is
                expensive.
            exclude_parent_tag:
                If True, samples that are also in the parent tag
                (``tag_data.prev_tag_id``) are left out. The difference is
                computed on the server.

        Returns:
            The filenames selected by the tag's bitmask.

        """
        if not exclude_parent_tag:
            bit_mask_data = tag_data.bit_mask_data
        else:
            # ask the server for "this tag minus its parent tag"
            request = TagArithmeticsRequest(
                tag_id1=tag_data.id,
                tag_id2=tag_data.prev_tag_id,
                operation=TagArithmeticsOperation.DIFFERENCE,
            )
            response: TagBitMaskResponse = \
                self._tags_api.perform_tag_arithmetics_bitmask(
                    dataset_id=self.dataset_id, body=request
                )
            bit_mask_data = response.bit_mask_data

        if not filenames_on_server:
            filenames_on_server = self.get_filenames()

        mask = BitMask.from_hex(bit_mask_data)
        return mask.masked_select_from_list(filenames_on_server)
Esempio n. 7
0
    def test_store_and_retrieve(self):
        # Arbitrary starting pattern; then set every 11th bit from 11 to 99.
        initial = int("0b01010100100100100100100010010100100100101001001010101010", 2)
        mask = BitMask(initial)
        for k in range(11, 100, 11):
            mask.set_kth_bit(k)

        # Serialize to both string formats, then parse each one back.
        hex_repr = mask.to_hex()
        bin_repr = mask.to_bin()
        restored_from_hex = BitMask.from_hex(hex_repr)
        restored_from_bin = BitMask.from_bin(bin_repr)

        # Both round-trips must reproduce the exact integer value.
        self.assertEqual(mask.x, restored_from_hex.x)
        self.assertEqual(mask.x, restored_from_bin.x)
Esempio n. 8
0
def _download_cli(cfg, is_cli_call=True):
    """Write the filenames of a tag to ``<tag_name>.txt`` and optionally
    download or copy the corresponding images.

    Args:
        cfg:
            Mapping with the keys ``tag_name``, ``dataset_id``, ``token``,
            ``exclude_parent_tag``, ``input_dir`` and ``output_dir``.
        is_cli_call:
            Not used by this function; kept for interface compatibility.

    If ``exclude_parent_tag`` is set and the tag has a parent, only samples
    that are in the tag but not in its parent are listed. A tag without a
    parent (``prev_tag_id`` is None) has nothing to subtract, so its own
    bitmask is used. The alternative would be a DIFFERENCE request with an
    empty second operand.

    The filename list is written to ``<tag_name>.txt`` in the current working
    directory. If ``output_dir`` is set, the images are then either
    downloaded from the API (when ``input_dir`` is empty) or copied from
    ``input_dir``.

    Missing parameters or an unknown tag are reported as warnings and the
    function returns early without raising.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name or not token or not dataset_id:
        print_as_warning(
            'Please specify all of the parameters tag_name, token and dataset_id'
        )
        print_as_warning('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # resolve the requested tag name to its id
    tag_name_id_dict = {
        tag.name: tag.id for tag in api_workflow_client._get_all_tags()
    }
    tag_id = tag_name_id_dict.get(tag_name)
    if tag_id is None:
        warnings.warn(f'The specified tag {tag_name} does not exist.')
        return

    # get tag data
    tag_data: TagData = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    parent_tag_id = tag_data.prev_tag_id
    if cfg["exclude_parent_tag"] and parent_tag_id is not None:
        # keep only the samples in this tag that are not in its parent tag
        tag_arithmetics_request = TagArithmeticsRequest(
            tag_id1=tag_data.id,
            tag_id2=parent_tag_id,
            operation=TagArithmeticsOperation.DIFFERENCE)
        bit_mask_response: TagBitMaskResponse = \
            api_workflow_client.tags_api.perform_tag_arithmetics(
                body=tag_arithmetics_request, dataset_id=dataset_id)
        bit_mask_data = bit_mask_response.bit_mask_data
    else:
        # parent not excluded, or the tag has no parent to subtract
        bit_mask_data = tag_data.bit_mask_data

    # translate the bitmask into filenames; the server filename list is
    # looked up once instead of once per selected sample
    chosen_samples_ids = BitMask.from_hex(bit_mask_data).to_indices()
    filenames_on_server = api_workflow_client.filenames_on_server
    samples = [filenames_on_server[i] for i in chosen_samples_ids]

    # store sample names in a .txt file; utf-8 so that non-ASCII filenames
    # are written identically regardless of the platform's default encoding
    filename = tag_name + '.txt'
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f'{item}\n' for item in samples)

    filepath = os.path.join(os.getcwd(), filename)
    msg = f'The list of files in tag {tag_name} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}'
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # download full images from api
        # NOTE(review): this download is selected by tag_name only, so
        # exclude_parent_tag does not restrict which images are downloaded
        # (unlike the .txt list above) — confirm this is intended.
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(
            f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.'
        )

        # create a dataset from the input directory
        dataset = data.LightlyDataset(input_dir=input_dir)

        # dump the selected samples into the output directory
        dataset.dump(output_dir, samples)
    def download_dataset(self,
                         output_dir: str,
                         tag_name: str = 'initial-tag',
                         verbose: bool = True):
        """Downloads images from the web-app and stores them in output_dir.

        Only datasets whose image type is ``ImageType.FULL`` can be
        downloaded. For every sample selected by the tag's bitmask, a read
        url for the full image is requested, the image is fetched, and it is
        passed to ``_make_dir_and_save_image`` together with ``output_dir``
        and the sample's filename on the server.

        Args:
            output_dir:
                Where to store the downloaded images.
            tag_name:
                Name of the tag which should be downloaded.
            verbose:
                Whether or not to show the progress bar.

        Raises:
            ValueError if the dataset has no full images (its image type is
            not ImageType.FULL) or if the specified tag does not exist on
            the dataset.
            RuntimeError if the connection to the server failed (not raised
            directly in this method; presumably by the helpers it calls —
            confirm).

        """

        # check if images are available
        # (only FULL datasets have downloadable images; anything else holds
        # only thumbnails or metadata)
        dataset = self.datasets_api.get_dataset_by_id(self.dataset_id)
        if dataset.img_type != ImageType.FULL:
            # only thumbnails or metadata available
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no downloadable images!"
            )

        # check if tag exists
        available_tags = self._get_all_tags()
        try:
            tag = next(tag for tag in available_tags if tag.name == tag_name)
        except StopIteration:
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no tag {tag_name}!")

        # get sample ids
        # the sample ids of the whole dataset are indexed with the same
        # positions as filenames_on_server below; presumably both lists share
        # the same ordering on the server — TODO confirm
        sample_ids = self.mappings_api.get_sample_mappings_by_dataset_id(
            self.dataset_id, field='_id')

        # keep only the samples whose bit is set in the tag's bitmask
        indices = BitMask.from_hex(tag.bit_mask_data).to_indices()
        sample_ids = [sample_ids[i] for i in indices]
        filenames = [self.filenames_on_server[i] for i in indices]

        if verbose:
            print(f'Downloading {len(sample_ids)} images:', flush=True)
            # NOTE(review): pbar is not closed in the lines visible here —
            # confirm whether a pbar.close() follows the loop.
            pbar = tqdm.tqdm(unit='imgs', total=len(sample_ids))

        # download images
        for sample_id, filename in zip(sample_ids, filenames):
            read_url = self.samples_api.get_sample_image_read_url_by_id(
                self.dataset_id,
                sample_id,
                type="full",
            )

            img = _get_image_from_read_url(read_url)
            _make_dir_and_save_image(output_dir, filename, img)

            if verbose:
                pbar.update(1)