예제 #1
0
 def test_operator_minus(self):
     """`a - b` returns the bits of `a` not in `b` and leaves `a` unchanged."""
     original = BitMask.from_bin("0b10111")
     snapshot = deepcopy(original)
     subtrahend = BitMask.from_bin("0b01100")
     expected = BitMask.from_bin("0b10011")

     # the operator result must equal the expected difference ...
     self.assertEqual(original - subtrahend, expected)
     # ... and the left operand must not have been mutated by `-`
     self.assertEqual(snapshot, original)
예제 #2
0
def _download_cli(cfg, is_cli_call=True):
    """Store the filenames of a tag in ``<tag_name>.txt`` and optionally fetch the files.

    Args:
        cfg:
            Mapping read with the keys 'tag_name', 'dataset_id', 'token',
            'input_dir' and 'output_dir'.
        is_cli_call:
            Accepted for interface compatibility; not used by this function.

    Behaviour:
        - Prints a hint and returns early if the tag name, token or dataset id
          is missing, or if the tag does not exist.
        - Writes one filename per line to ``<tag_name>.txt`` in the current
          working directory.
        - With only 'output_dir' set, downloads the tag's images via the API.
        - With both 'input_dir' and 'output_dir' set, copies the tag's files
          from the local input directory to the output directory.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name:
        print('Please specify a tag name')
        print('For help, try: lightly-download --help')
        return

    if not token or not dataset_id:
        print('Please specify your access token and dataset id')
        print('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # resolve the tag name to its id
    tag_name_id_dict = {
        tag.name: tag.id for tag in api_workflow_client._get_all_tags()
    }
    tag_id = tag_name_id_dict.get(tag_name)
    if tag_id is None:
        print(f'The specified tag {tag_name} does not exist.')
        return

    # get tag data
    tag_data = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    # map the set bits of the tag's bitmask to filenames on the server
    chosen_samples_ids = BitMask.from_hex(tag_data.bit_mask_data).to_indices()
    samples = [
        api_workflow_client.filenames_on_server[i] for i in chosen_samples_ids
    ]

    # Store sample names in a .txt file. The encoding is explicit so that
    # non-ASCII filenames do not depend on the platform's default encoding.
    filename = tag_name + '.txt'
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f'{item}\n' for item in samples)

    msg = 'The list of files in tag {} is stored at: '.format(tag_name)
    msg += os.path.join(os.getcwd(), filename)
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # download full images from api
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(f'Copying files from {input_dir} to {output_dir}.')

        # create a dataset from the input directory
        dataset = data.LightlyDataset(input_dir=input_dir)

        # dump the dataset in the output directory
        dataset.dump(output_dir, samples)
예제 #3
0
    def test_get_and_set_outside_of_range(self):
        """A bit far beyond the initial value's width can be read and set."""
        far_index = 100  # well past the 8 bits of the initial value
        mask = BitMask.from_bin("0b11110000")

        self.assertFalse(mask.get_kth_bit(far_index))
        mask.set_kth_bit(far_index)
        self.assertTrue(mask.get_kth_bit(far_index))
예제 #4
0
    def _get_preselected_tag_bitmask(self):
        """Return the bitmask of the preselected tag.

        If no preselected tag id is set, an empty bitmask ('0x0') is returned,
        i.e. no sample belongs to the preselected tag. Otherwise the tag is
        fetched through ``self.api_workflow_client._tags_api`` and its
        ``bit_mask_data`` is decoded.
        """
        if self._preselected_tag_id is not None:
            tag_data = self.api_workflow_client._tags_api.get_tag_by_tag_id(
                self.api_workflow_client.dataset_id,
                tag_id=self._preselected_tag_id)
            return BitMask.from_hex(tag_data.bit_mask_data)

        # no preselected tag: nothing is preselected
        return BitMask.from_hex('0x0')
예제 #5
0
파일: agent.py 프로젝트: stjordanis/lightly
    def _set_labeled_and_unlabeled_set(self,
                                       preselected_tag_data: TagData = None):
        """Set ``labeled_set`` and ``unlabeled_set`` from the preselected and query tags.

        The bitmasks of both tags are decoded and mapped to filenames through
        ``api_workflow_client.filenames_on_server``.

        Args:
            preselected_tag_data:
                Optional data of the preselected tag; when given it is used
                instead of fetching the tag from the API.

        Side effects:
            ``self.labeled_set`` becomes the preselected tag's filenames (empty
            without a preselected tag). ``self.unlabeled_set`` is computed from
            the query tag (or all server filenames) only if it does not exist
            yet, and is then filtered to exclude every labeled filename.
        """
        client = self.api_workflow_client

        def filenames_of(tag_data):
            # decode the hex bitmask and look up the filename of each set index
            indices = BitMask.from_hex(tag_data.bit_mask_data).to_indices()
            return [client.filenames_on_server[i] for i in indices]

        if self.preselected_tag_id is not None:
            if preselected_tag_data is None:
                preselected_tag_data = client.tags_api.get_tag_by_tag_id(
                    client.dataset_id, tag_id=self.preselected_tag_id)
            self.labeled_set = filenames_of(preselected_tag_data)
        else:
            self.labeled_set = []

        if not hasattr(self, "unlabeled_set"):
            if self.query_tag_id is not None:
                query_tag_data = client.tags_api.get_tag_by_tag_id(
                    client.dataset_id, tag_id=self.query_tag_id)
                self.unlabeled_set = filenames_of(query_tag_data)
            else:
                self.unlabeled_set = client.filenames_on_server

        # set lookup keeps the filter linear in the size of unlabeled_set
        labeled_lookup = set(self.labeled_set)
        self.unlabeled_set = [
            name for name in self.unlabeled_set if name not in labeled_lookup
        ]
예제 #6
0
    def _get_query_tag_bitmask(self):
        """Fetch the query tag from the API and return its decoded bitmask."""
        client = self.api_workflow_client
        tag_data = client._tags_api.get_tag_by_tag_id(
            client.dataset_id, tag_id=self._query_tag_id)
        return BitMask.from_hex(tag_data.bit_mask_data)
예제 #7
0
    def _set_labeled_and_unlabeled_set(self,
                                       preselected_tag_data: TagData = None):
        """Update the labeled, added and unlabeled sets from the preselected and query tags.

        Bitmasks are kept in ``bitmask_labeled_set``, ``bitmask_added_set`` and
        ``bitmask_unlabeled_set``. The matching filename lists are stored in
        ``labeled_set``, ``added_set`` and ``unlabeled_set`` by selecting from
        ``api_workflow_client.filenames_on_server``.

        Args:
            preselected_tag_data:
                Optional data of the preselected tag; when given it is used
                instead of fetching the tag from the API.
        """
        client = self.api_workflow_client

        if not hasattr(self, "bitmask_labeled_set"):
            # first call: start from empty labeled and added bitmasks
            self.bitmask_labeled_set = BitMask.from_hex("0x0")
            self.bitmask_added_set = BitMask.from_hex("0x0")

        if self.preselected_tag_id is not None:
            if preselected_tag_data is None:
                preselected_tag_data = client.tags_api.get_tag_by_tag_id(
                    client.dataset_id,
                    tag_id=self.preselected_tag_id)
            latest_labeled = BitMask.from_hex(preselected_tag_data.bit_mask_data)
            # bits in the new labeled mask that were not in the previous one
            self.bitmask_added_set = latest_labeled - self.bitmask_labeled_set
            self.bitmask_labeled_set = latest_labeled
        # without a preselected tag the current labeled/added bitmasks are kept

        if self.query_tag_id is not None:
            query_tag_data = client.tags_api.get_tag_by_tag_id(
                client.dataset_id, tag_id=self.query_tag_id)
            query_bitmask = BitMask.from_hex(query_tag_data.bit_mask_data)
        else:
            # no query tag: every sample on the server is a candidate
            query_bitmask = BitMask.from_length(
                len(client.filenames_on_server))
        self.bitmask_unlabeled_set = query_bitmask - self.bitmask_labeled_set

        all_filenames = client.filenames_on_server
        self.labeled_set = self.bitmask_labeled_set.masked_select_from_list(
            all_filenames)
        self.added_set = self.bitmask_added_set.masked_select_from_list(
            all_filenames)
        self.unlabeled_set = self.bitmask_unlabeled_set.masked_select_from_list(
            all_filenames)
예제 #8
0
    def test_get_and_set(self):
        """Single bits can be set and unset in place."""
        mask = BitMask.from_bin("0b11110000")
        get_bit = mask.get_kth_bit

        # bit 2 is in the cleared low nibble: setting it turns it on
        self.assertFalse(get_bit(2))
        mask.set_kth_bit(2)
        self.assertTrue(get_bit(2))

        # bit 4 is in the set high nibble: unsetting it turns it off
        self.assertTrue(get_bit(4))
        mask.unset_kth_bit(4)
        self.assertFalse(get_bit(4))
예제 #9
0
    def test_inverse(self):
        """Inverting flips every bit within the given total bit width."""
        # An int carries no width of its own, so invert() is given the number
        # of bits to flip, as in the other invert tests.
        x = int("0b11110000", 2)
        y = int("0b00001111", 2)
        mask = BitMask(x)
        mask.invert(8)
        self.assertEqual(mask.x, y)

        x = int("0b010101010101010101", 2)
        y = int("0b101010101010101010", 2)
        mask = BitMask(x)
        mask.invert(18)
        self.assertEqual(mask.x, y)
예제 #10
0
    def test_nonzero_bits(self):
        """Bits set at sparse, large indices are reported back by to_indices()."""
        mask = BitMask.from_bin("0b0")
        indices = [100, 1000, 10_000, 100_000]

        self.assertEqual(mask.x, 0)
        for index in indices:
            mask.set_kth_bit(index)

        # bits were set, so the value must now be strictly positive
        self.assertGreater(mask.x, 0)
        # compare whole lists: zip() would silently ignore missing/extra indices
        self.assertEqual(list(mask.to_indices()), indices)
예제 #11
0
    def test_masked_select_from_list(self):
        """Selecting with a mask and its inverse splits a 0/1 list into ones and zeros."""
        n = 1000
        # random 0/1 items; the trailing 0 and 1 guarantee both groups are non-empty
        items = [randint(0, 1) for _ in range(n - 2)]
        items += [0, 1]

        mask = BitMask.from_length(n)
        for position, value in enumerate(items):
            if value:
                mask.set_kth_bit(position)
            else:
                mask.unset_kth_bit(position)

        selected_ones = mask.masked_select_from_list(items)
        mask.invert(n)
        selected_zeros = mask.masked_select_from_list(items)

        self.assertGreater(len(selected_ones), 0)
        self.assertGreater(len(selected_zeros), 0)
        self.assertTrue(all(value > 0 for value in selected_ones))
        self.assertTrue(all(value == 0 for value in selected_zeros))
예제 #12
0
    def test_invert(self):
        """Inverting a random bitstring flips each of its ``length`` bits."""
        length = 10
        bitstring = self.random_bitstring(length)

        mask = BitMask.from_bin(bitstring)
        mask.invert(length)
        inverted = mask.to_bin()

        # Strip the '0b' prefix and left-pad to `length`. to_bin() is not
        # guaranteed to spell out leading zeros, and comparing only over the
        # shorter string would skip those positions.
        original_bits = bitstring[2:].zfill(length)
        inverted_bits = inverted[2:].zfill(length)
        for position in range(1, length + 1):
            expected = '1' if original_bits[-position] == '0' else '0'
            self.assertEqual(inverted_bits[-position], expected)
예제 #13
0
    def get_filenames_in_tag(
        self,
        tag_data: TagData,
        filenames_on_server: List[str] = None,
        exclude_parent_tag: bool = False,
    ) -> List[str]:
        """Return the filenames of the samples in a tag.

        Args:
            tag_data:
                The data of the tag.
            filenames_on_server:
                List of all filenames on the server. If it is not given (or is
                empty) it is fetched with ``self.get_filenames()``, which is
                expensive.
            exclude_parent_tag:
                If True, the server computes the difference between the tag
                and its parent tag (``prev_tag_id``), and only those samples
                are returned.

        Returns:
            The filenames selected by the (possibly reduced) tag bitmask.
        """
        if not exclude_parent_tag:
            bit_mask_data = tag_data.bit_mask_data
        else:
            difference_request = TagArithmeticsRequest(
                tag_id1=tag_data.id,
                tag_id2=tag_data.prev_tag_id,
                operation=TagArithmeticsOperation.DIFFERENCE)
            response: TagBitMaskResponse = \
                self._tags_api.perform_tag_arithmetics_bitmask(
                    body=difference_request, dataset_id=self.dataset_id
                )
            bit_mask_data = response.bit_mask_data

        if not filenames_on_server:
            filenames_on_server = self.get_filenames()

        return BitMask.from_hex(bit_mask_data).masked_select_from_list(
            filenames_on_server)
예제 #14
0
    def create_tag_from_filenames(self,
                                  fnames_new_tag: List[str],
                                  new_tag_name: str,
                                  parent_tag_id: str = None) -> TagData:
        """Creates a new tag from a list of filenames.

        Args:
            fnames_new_tag:
                A list of filenames to be included in the new tag.
            new_tag_name:
                The name of the new tag.
            parent_tag_id:
                The tag defining where to sample from, default: None resolves to the initial-tag.

        Returns:
            The newly created tag.

        Raises:
            RuntimeError
                If a tag named ``new_tag_name`` already exists; if the dataset
                has no tags, or (when ``parent_tag_id`` is None) has no tag
                named 'initial-tag'; or if the number of server filenames that
                match differs from the number of distinct filenames given.
        """

        # make sure the tag name does not exist yet
        tags = self.get_all_tags()
        if new_tag_name in [tag.name for tag in tags]:
            raise RuntimeError(
                f'There already exists a tag with tag_name {new_tag_name}.')
        if len(tags) == 0:
            raise RuntimeError('There exists no initial-tag for this dataset.')

        # Fall back to the initial tag if no parent tag is provided. A default
        # for next() avoids leaking a bare StopIteration when it is missing.
        if parent_tag_id is None:
            initial_tag = next(
                (tag for tag in tags if tag.name == 'initial-tag'), None)
            if initial_tag is None:
                raise RuntimeError(
                    'There exists no initial-tag for this dataset.')
            parent_tag_id = initial_tag.id

        # get list of filenames from tag
        fnames_server = self.get_filenames()
        tot_size = len(fnames_server)

        # create new bitmask for the new tag; bit i <-> i-th server filename
        wanted = set(fnames_new_tag)
        bitmask = BitMask(0)
        for i, fname in enumerate(fnames_server):
            if fname in wanted:
                bitmask.set_kth_bit(i)

        # quick sanity check
        num_selected_samples = len(bitmask.to_indices())
        if num_selected_samples != len(wanted):
            # guard the example lookup so an empty dataset does not turn this
            # error into an IndexError
            if fnames_server:
                example = (f'Valid filename example from the dataset: '
                           f'{fnames_server[0]}')
            else:
                example = 'The dataset contains no filenames.'
            raise RuntimeError(
                f'An error occured when creating the new subset! '
                f'Out of the {len(wanted)} filenames you provided '
                f'to create a new tag, only {num_selected_samples} have been '
                f'found on the server. '
                f'Make sure you use the correct filenames. '
                f'{example}')

        # create new tag
        tag_data_dict = {
            'name': new_tag_name,
            'prevTagId': parent_tag_id,
            'bitMaskData': bitmask.to_hex(),
            'totSize': tot_size
        }

        new_tag = self._tags_api.create_tag_by_dataset_id(
            tag_data_dict, self.dataset_id)

        return new_tag
예제 #15
0
 def assert_difference(self, bistring_1: str, bitstring_2: str, target: str):
     """Assert that ``difference`` of two bitstring masks gives ``target``.

     The difference is applied in place to the mask built from ``bistring_1``.
     Its integer value is then compared with ``int(target, 2)``.
     """
     minuend = BitMask.from_bin(bistring_1)
     minuend.difference(BitMask.from_bin(bitstring_2))
     self.assertEqual(minuend.x, int(target, 2))
예제 #16
0
def _download_cli(cfg, is_cli_call=True):
    """Store the filenames of a tag in ``<tag_name>.txt`` and optionally fetch the files.

    Args:
        cfg:
            Mapping read with the keys 'tag_name', 'dataset_id', 'token',
            'exclude_parent_tag', 'input_dir' and 'output_dir'.
        is_cli_call:
            Accepted for interface compatibility; not used by this function.

    Behaviour:
        - Warns and returns early if tag name, token or dataset id is missing,
          or if the tag does not exist.
        - With 'exclude_parent_tag', the listed samples are the server-side
          difference between the tag and its parent tag.
        - Writes one filename per line to ``<tag_name>.txt`` in the current
          working directory.
        - With only 'output_dir' set, downloads the tag's images via the API.
        - With both 'input_dir' and 'output_dir' set, copies the listed files
          from the local input directory to the output directory.
    """
    tag_name = cfg['tag_name']
    dataset_id = cfg['dataset_id']
    token = cfg['token']

    if not tag_name or not token or not dataset_id:
        print_as_warning(
            'Please specify all of the parameters tag_name, token and dataset_id'
        )
        print_as_warning('For help, try: lightly-download --help')
        return

    api_workflow_client = ApiWorkflowClient(token=token, dataset_id=dataset_id)

    # resolve the tag name to its id
    tag_name_id_dict = {
        tag.name: tag.id for tag in api_workflow_client._get_all_tags()
    }
    tag_id = tag_name_id_dict.get(tag_name)
    if tag_id is None:
        warnings.warn(f'The specified tag {tag_name} does not exist.')
        return

    # get tag data
    tag_data: TagData = api_workflow_client.tags_api.get_tag_by_tag_id(
        dataset_id=dataset_id, tag_id=tag_id)

    if cfg["exclude_parent_tag"]:
        # let the server compute "this tag minus its parent tag"
        tag_arithmetics_request = TagArithmeticsRequest(
            tag_id1=tag_data.id,
            tag_id2=tag_data.prev_tag_id,
            operation=TagArithmeticsOperation.DIFFERENCE)
        bit_mask_response: TagBitMaskResponse \
            = api_workflow_client.tags_api.perform_tag_arithmetics(
                body=tag_arithmetics_request, dataset_id=dataset_id)
        bit_mask_data = bit_mask_response.bit_mask_data
    else:
        bit_mask_data = tag_data.bit_mask_data

    # map the set bits of the bitmask to filenames on the server
    chosen_samples_ids = BitMask.from_hex(bit_mask_data).to_indices()
    samples = [
        api_workflow_client.filenames_on_server[i] for i in chosen_samples_ids
    ]

    # Store sample names in a .txt file. The encoding is explicit so that
    # non-ASCII filenames do not depend on the platform's default encoding.
    filename = tag_name + '.txt'
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f'{item}\n' for item in samples)

    filepath = os.path.join(os.getcwd(), filename)
    msg = f'The list of files in tag {cfg["tag_name"]} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}'
    print(msg, flush=True)

    if not cfg['input_dir'] and cfg['output_dir']:
        # download full images from api
        # NOTE(review): download_dataset is given only tag_name, so this branch
        # downloads the whole tag even when exclude_parent_tag is set — confirm
        # whether that is intended.
        output_dir = fix_input_path(cfg['output_dir'])
        api_workflow_client.download_dataset(output_dir, tag_name=tag_name)

    elif cfg['input_dir'] and cfg['output_dir']:
        input_dir = fix_input_path(cfg['input_dir'])
        output_dir = fix_input_path(cfg['output_dir'])
        print(
            f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.'
        )

        # create a dataset from the input directory
        dataset = data.LightlyDataset(input_dir=input_dir)

        # dump the dataset in the output directory
        dataset.dump(output_dir, samples)
예제 #17
0
 def test_union(self):
     """union() merges the bits of the other mask into the receiver."""
     low_bit = BitMask.from_bin("0b001")
     high_bit = BitMask.from_bin("0b100")
     # the result is read back from the receiver, so union() works in place
     low_bit.union(high_bit)
     self.assertEqual(low_bit.x, int("0b101", 2))
예제 #18
0
 def test_intersection(self):
     """intersection() keeps only the bits shared with the other mask."""
     receiver = BitMask.from_bin("0b101")
     other = BitMask.from_bin("0b100")
     # the result is read back from the receiver, so intersection() works in place
     receiver.intersection(other)
     self.assertEqual(receiver.x, int("0b100", 2))
    def download_dataset(self,
                         output_dir: str,
                         tag_name: str = 'initial-tag',
                         verbose: bool = True):
        """Downloads images from the web-app and stores them in output_dir.

        Args:
            output_dir:
                Where to store the downloaded images.
            tag_name:
                Name of the tag which should be downloaded.
            verbose:
                Whether or not to show the progress bar.

        Raises:
            ValueError if the dataset has no full images or the specified tag
            does not exist on the dataset.
            RuntimeError if the connection to the server failed.

        """

        # check if images are available
        dataset = self.datasets_api.get_dataset_by_id(self.dataset_id)
        if dataset.img_type != ImageType.FULL:
            # only thumbnails or metadata available
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no downloadable images!"
            )

        # check if tag exists; a default for next() avoids chaining a
        # StopIteration into the ValueError
        available_tags = self._get_all_tags()
        tag = next((t for t in available_tags if t.name == tag_name), None)
        if tag is None:
            raise ValueError(
                f"Dataset with id {self.dataset_id} has no tag {tag_name}!")

        # get sample ids
        sample_ids = self.mappings_api.get_sample_mappings_by_dataset_id(
            self.dataset_id, field='_id')

        # keep only the samples whose bit is set in the tag's bitmask
        indices = BitMask.from_hex(tag.bit_mask_data).to_indices()
        sample_ids = [sample_ids[i] for i in indices]
        filenames = [self.filenames_on_server[i] for i in indices]

        if verbose:
            print(f'Downloading {len(sample_ids)} images:', flush=True)

        # `disable` gives a single code path for both verbosity levels, and
        # the context manager closes the bar even if a download raises.
        with tqdm.tqdm(unit='imgs', total=len(sample_ids),
                       disable=not verbose) as pbar:
            for sample_id, filename in zip(sample_ids, filenames):
                read_url = self.samples_api.get_sample_image_read_url_by_id(
                    self.dataset_id,
                    sample_id,
                    type="full",
                )

                img = _get_image_from_read_url(read_url)
                _make_dir_and_save_image(output_dir, filename, img)

                pbar.update(1)
예제 #20
0
    def test_store_and_retrieve(self):
        """A mask survives a round trip through both its hex and binary forms."""
        x = int("0b01010100100100100100100010010100100100101001001010101010", 2)
        mask = BitMask(x)
        # set bits 11, 22, ..., 99 on top of the initial value
        for k in range(11, 100, 11):
            mask.set_kth_bit(k)

        as_hex = mask.to_hex()
        as_bin = mask.to_bin()

        restored_from_hex = BitMask.from_hex(as_hex)
        restored_from_bin = BitMask.from_bin(as_bin)

        self.assertEqual(mask.x, restored_from_hex.x)
        self.assertEqual(mask.x, restored_from_bin.x)
예제 #21
0
 def test_subset_a_list(self):
     """masked_select_from_list picks the items whose bit is set.

     The least significant bit corresponds to index 0, the convention used by
     test_masked_select_from_list_example (0b001101 -> [1, 3, 4]) and by
     test_masked_select_from_list, which sets bit i for list item i.
     0b0101 has bits 0 and 2 set, so items 0 and 2 are selected.
     """
     list_ = [4, 7, 9, 1]
     mask = BitMask.from_bin("0b0101")
     target_masked_list = [4, 9]
     masked_list = mask.masked_select_from_list(list_)
     self.assertEqual(target_masked_list, masked_list)
예제 #22
0
 def test_bitmask_from_length(self):
     """from_length(n) produces a mask of n set bits."""
     self.assertEqual(BitMask.from_length(4).to_bin(), "0b" + "1" * 4)
예제 #23
0
 def test_large_bitmasks(self):
     """A 5678-bit all-ones mask round-trips through from_bin/to_bin unchanged."""
     ones = "1" * 5678
     bitstring = f"0b{ones}"
     self.assertEqual(BitMask.from_bin(bitstring).to_bin(), bitstring)
예제 #24
0
 def test_equal(self):
     """Two masks built independently from the same bitstring compare equal."""
     first, second = (BitMask.from_bin("0b101") for _ in range(2))
     self.assertEqual(first, second)
예제 #25
0
 def test_masked_select_from_list_example(self):
     """0b001101 has bits 0, 2 and 3 set (LSB = index 0), selecting [1, 3, 4]."""
     mask = BitMask.from_bin('0b001101')
     self.assertListEqual(
         mask.masked_select_from_list([1, 2, 3, 4, 5, 6]), [1, 3, 4])