def test_get_samples_no_tag(self):
    '''Make sure get_samples raises RuntimeError in the "no_tag_exists" scenario.

    The mocked server is configured with mode='no_tag_exists' and
    psuccess=1. (presumably no random request failures — confirm against
    setup), so the RuntimeError is expected to come from the scenario
    itself.
    '''
    self.setup(psuccess=1., mode='no_tag_exists')
    responses.add_callback(responses.GET, self.tag_url,
                           callback=self.tags_callback,
                           content_type='application/json')
    responses.add_callback(responses.GET, self.dst_url,
                           callback=self.samples_callback,
                           content_type='application/json')
    # Only the call expected to raise belongs inside assertRaises: any
    # statement after it in the same block is unreachable once it raises,
    # so the former sample-comparison loop here could never execute.
    with self.assertRaises(RuntimeError):
        get_samples(self.dataset_id, self.token)
def test_get_samples_some_success(self):
    '''Check get_samples output against the expected samples under partial failure.

    The mocked server is configured with psuccess=.9 (presumably a fraction
    of requests fail — confirm against setup). The fetch is repeated N
    times, and each returned sample is compared line by line against
    self.samples.
    '''
    self.setup(psuccess=.9)
    # Register the mocked GET endpoints in the same order as before:
    # tags first, then samples.
    routes = (
        (self.tag_url, self.tags_callback),
        (self.dst_url, self.samples_callback),
    )
    for url, handler in routes:
        responses.add_callback(responses.GET, url,
                               callback=handler,
                               content_type='application/json')
    for _ in range(N):
        fetched = get_samples(self.dataset_id, self.token)
        expected = self.samples.splitlines()
        for actual, wanted in zip(fetched, expected):
            self.assertEqual(actual, wanted)
def get_samples_by_tag(tag_name: str,
                       dataset_id: str,
                       token: str,
                       mode: str = 'list',
                       filenames: List[str] = None):
    """Get the files associated with a given tag and dataset.

    Asks the servers for all samples in a given tag and dataset. If mode is
    mask or indices, the list of all filenames must be specified.
    Can return either the list of all filenames in the tag, or a mask or
    indices indicating which of the provided filenames are in the tag.

    Args:
        tag_name:
            Name of the tag to query.
        dataset_id:
            The unique identifier of the dataset.
        token:
            Token for authentication.
        mode:
            Return type, must be in ["list", "mask", "indices"].
        filenames:
            List of all filenames. Required for modes "mask" and "indices".

    Returns:
        Either list of filenames, binary mask, or list of indices
        specifying the samples in the requested tag.

    Raises:
        ValueError:
            If mode is not one of "list", "mask", "indices", or if
            filenames is None for mode "mask" or "indices". Both checks
            happen before any request is sent.
        RuntimeError:
            If the samples in the tag cannot be matched one-to-one onto
            the provided filenames. Errors raised by get_samples (e.g.
            RuntimeError) propagate unchanged.
    """
    # Validate all arguments before contacting the server so a bad call
    # fails fast instead of after a network round trip.
    if mode not in ('list', 'mask', 'indices'):
        msg = f'Got illegal mode "{mode}"! '
        msg += 'Must be in ["list", "mask", "indices"]'
        raise ValueError(msg)

    if mode in ('mask', 'indices') and filenames is None:
        msg = f'Argument filenames must not be None for mode "{mode}"!'
        raise ValueError(msg)

    samples = get_samples(dataset_id, token, tag_name=tag_name)

    if mode == 'list':
        return samples

    # Build the membership set once; rebuilding it per filename made the
    # lookup quadratic in the number of samples and filenames.
    sample_set = set(samples)

    if mode == 'mask':
        mask = [1 if f in sample_set else 0 for f in filenames]
        # Every tagged sample must appear exactly once among filenames.
        if sum(mask) != len(samples):
            msg = 'Error during mapping from samples to filenames: '
            msg += 'sum(mask) != len(samples) with lengths '
            msg += f'{sum(mask)} and {len(samples)}'
            raise RuntimeError(msg)
        return mask

    # mode == 'indices' (guaranteed by the validation above)
    indices = [i for i, f in enumerate(filenames) if f in sample_set]
    if len(indices) != len(samples):
        msg = 'Error during mapping from samples to filenames: '
        msg += 'len(indices) != len(samples) with lengths '
        msg += f'{len(indices)} and {len(samples)}.'
        raise RuntimeError(msg)
    return indices