Example #1
0
    def test_get_samples_no_tag(self):
        '''Make sure get_samples raises a RuntimeError in the error scenario
        configured by setup(mode='no_tag_exists').

        Only the get_samples call is wrapped in assertRaises: any statement
        placed after the raising call inside the context manager would never
        execute, so assertions there would be dead code.

        '''

        self.setup(psuccess=1., mode='no_tag_exists')

        responses.add_callback(responses.GET,
                               self.tag_url,
                               callback=self.tags_callback,
                               content_type='application/json')

        responses.add_callback(responses.GET,
                               self.dst_url,
                               callback=self.samples_callback,
                               content_type='application/json')

        with self.assertRaises(RuntimeError):
            get_samples(self.dataset_id, self.token)
Example #2
0
    def test_get_samples_some_success(self):
        '''Check that get_samples yields the expected samples when the mock
        is configured with psuccess=.9 (presumably some mocked requests
        fail -- confirm against setup()).

        The check is repeated N times; N is defined outside this view.

        '''
        self.setup(psuccess=.9)

        def register(url, callback):
            # Mock a JSON GET endpoint answered by the given callback.
            responses.add_callback(responses.GET,
                                   url,
                                   callback=callback,
                                   content_type='application/json')

        register(self.tag_url, self.tags_callback)
        register(self.dst_url, self.samples_callback)

        for _ in range(N):
            fetched = get_samples(self.dataset_id, self.token)
            expected = self.samples.splitlines()
            # zip stops at the shorter sequence, so only the overlapping
            # prefix of both sequences is compared.
            for got, want in zip(fetched, expected):
                self.assertEqual(got, want)
Example #3
0
def get_samples_by_tag(tag_name: str,
                       dataset_id: str,
                       token: str,
                       mode: str = 'list',
                       filenames: List[str] = None):
    """Get the files associated with a given tag and dataset.

    Asks the servers for all samples in a given tag and dataset. If mode is
    mask or indices, the list of all filenames must be specified. Can return
    either the list of all filenames in the tag, or a mask or indices
    indicating which of the provided filenames are in the tag.

    Args:
        tag_name:
            Name of the tag to query.
        dataset_id:
            The unique identifier of the dataset.
        token:
            Token for authentication.
        mode:
            Return type, must be in ["list", "mask", "indices"].
        filenames:
            List of all filenames. Required for modes "mask" and "indices".

    Returns:
        Either list of filenames, binary mask, or list of indices
        specifying the samples in the requested tag:
        - "list": the samples returned by get_samples for the tag.
        - "mask": one entry per filename, 1 if the filename is among the
          tag's samples and 0 otherwise.
        - "indices": ascending positions in filenames of the files that
          are among the tag's samples.

    Raises:
        ValueError: If mode is not one of "list", "mask", "indices", or if
            filenames is None for mode "mask" or "indices". Both checks run
            before any request is made.
        RuntimeError: If the number of matched filenames differs from the
            number of samples in the tag (e.g. a tagged sample is missing
            from filenames, or filenames contains duplicates).

    """

    # Validate arguments before talking to the server.
    if mode not in ('list', 'mask', 'indices'):
        msg = f'Got illegal mode "{mode}"! '
        msg += 'Must be in ["list", "mask", "indices"]'
        raise ValueError(msg)

    if mode in ('mask', 'indices') and filenames is None:
        msg = f'Argument filenames must not be None for mode "{mode}"!'
        raise ValueError(msg)

    samples = get_samples(dataset_id, token, tag_name=tag_name)

    if mode == 'list':
        return samples

    # Build the membership set once rather than once per filename.
    tagged = set(samples)

    if mode == 'mask':
        mask = [1 if f in tagged else 0 for f in filenames]
        if sum(mask) != len(samples):
            msg = 'Error during mapping from samples to filenames: '
            msg += 'sum(mask) != len(samples) with lengths '
            msg += f'{sum(mask)} and {len(samples)}'
            raise RuntimeError(msg)
        return mask

    # Only mode == 'indices' remains after validation above.
    indices = [i for i, f in enumerate(filenames) if f in tagged]
    if len(indices) != len(samples):
        msg = 'Error during mapping from samples to filenames: '
        msg += 'len(indices) != len(samples) with lengths '
        msg += f'{len(indices)} and {len(samples)}.'
        raise RuntimeError(msg)
    return indices