def test_masked_select_from_list(self):
    """Check that masked_select_from_list picks elements by bit value.

    A random 0/1 list is mirrored into a BitMask (bit k is set when
    element k is non-zero). Selecting with that mask must yield only
    non-zero items; after ``invert`` the selection must yield only zeros.
    """
    n = 1000
    # The fixed [0, 1] tail guarantees at least one zero and one one,
    # so neither selection can come back empty.
    values = [randint(0, 1) for _ in range(n - 2)]
    values.extend([0, 1])

    mask = BitMask.from_length(n)
    for position, value in enumerate(values):
        if value != 0:
            mask.set_kth_bit(position)
        else:
            mask.unset_kth_bit(position)

    selected_ones = mask.masked_select_from_list(values)
    # presumably flips the first n bits of the mask -- confirm in BitMask
    mask.invert(n)
    selected_zeros = mask.masked_select_from_list(values)

    self.assertGreater(len(selected_ones), 0)
    self.assertGreater(len(selected_zeros), 0)
    self.assertTrue(all(value > 0 for value in selected_ones))
    self.assertTrue(all(value == 0 for value in selected_zeros))
def _set_labeled_and_unlabeled_set(self, preselected_tag_data: TagData = None):
    """Refresh the labeled, added and unlabeled sets from the tag bitmasks.

    The bitmasks of the preselected tag and the query tag are turned into
    lists of filenames using ``api_workflow_client.filenames_on_server``.

    Args:
        preselected_tag_data:
            Optional tag data of the preselected tag. When omitted (and a
            preselected tag id is set) it is fetched through the tags API.
    """

    def fetch_tag(tag_id):
        # Load the tag data for tag_id from the current dataset.
        return self.api_workflow_client.tags_api.get_tag_by_tag_id(
            self.api_workflow_client.dataset_id, tag_id=tag_id)

    def select_filenames(bitmask):
        # Map a bitmask onto the filenames known to the server.
        return bitmask.masked_select_from_list(
            self.api_workflow_client.filenames_on_server)

    # First call: start from an empty labeled set and an empty added set.
    if not hasattr(self, "bitmask_labeled_set"):
        self.bitmask_labeled_set = BitMask.from_hex("0x0")
        self.bitmask_added_set = BitMask.from_hex("0x0")

    # Without a preselected tag the current labeled/added masks are kept.
    if self.preselected_tag_id is not None:
        if preselected_tag_data is None:
            preselected_tag_data = fetch_tag(self.preselected_tag_id)
        updated_labeled = BitMask.from_hex(preselected_tag_data.bit_mask_data)
        # "Added" is whatever is in the new labeled mask but not the old one.
        self.bitmask_added_set = updated_labeled - self.bitmask_labeled_set
        self.bitmask_labeled_set = updated_labeled

    if self.query_tag_id is not None:
        query_mask = BitMask.from_hex(
            fetch_tag(self.query_tag_id).bit_mask_data)
    else:
        # No query tag: use a mask as long as the server's filename list.
        query_mask = BitMask.from_length(
            len(self.api_workflow_client.filenames_on_server))
    self.bitmask_unlabeled_set = query_mask - self.bitmask_labeled_set

    self.labeled_set = select_filenames(self.bitmask_labeled_set)
    self.added_set = select_filenames(self.bitmask_added_set)
    self.unlabeled_set = select_filenames(self.bitmask_unlabeled_set)
def test_bitmask_from_length(self):
    """A mask built with from_length(4) renders as four set bits."""
    self.assertEqual(BitMask.from_length(4).to_bin(), "0b1111")