Esempio n. 1
0
def prepare_datapipe(data_dir):
    """Build the 'train' and 'val' datapipes from tar archives in ``data_dir``.

    For each phase, the archives matching ``<phase>*.tar.gz`` are listed,
    loaded, unpacked, decoded and grouped in pairs (``group_size=2``), then
    wrapped in an adapter that applies the phase transform and converts each
    ``category_id`` into an integer class index.

    Args:
        data_dir: Directory searched for ``train*.tar.gz`` and
            ``val*.tar.gz`` archives.

    Returns:
        A dict ``{'train': datapipe, 'val': datapipe}``.
    """
    from torch.utils.data import IterDataPipe

    # A temporary class that mimics the current dataset behavior
    # so the dataloader can use the datapipe directly.
    class TransferDatapipe(IterDataPipe):
        """Adapter yielding ``(transformed_sample, class_index)`` tuples.

        Class indices are assigned in first-seen order while iterating.
        NOTE(review): each instance keeps its own label -> index mapping, so
        the 'train' and 'val' pipes may number the same category differently
        if they encounter labels in a different order -- confirm this is
        acceptable for the caller.
        """

        def __init__(self, datapipe, phase, length=-1):
            super().__init__()
            self.datapipe = datapipe
            self.length = length  # -1 means "length unknown"
            self.transform = get_transform_api()[phase]
            self.classes = []     # labels in the order they were first seen
            self.class_ids = {}   # label -> integer class index
            self.curr_id = -1     # last index handed out

        def __iter__(self):
            for item in self.datapipe:
                # Each item is a group of two (key, value) records: the first
                # value is fed to the transform, the second carries the label.
                label = item[1][1]['category_id']
                if label not in self.class_ids:
                    self.classes.append(label)
                    self.curr_id += 1
                    self.class_ids[label] = self.curr_id
                yield (self.transform(item[0][1]), self.class_ids[label])

        def __len__(self):
            if self.length == -1:
                raise NotImplementedError
            return self.length

    def _build_phase_datapipe(phase):
        """Assemble the full pipeline for one phase ('train' or 'val')."""
        listed = dp.iter.ListDirFiles(data_dir, phase + '*.tar.gz')
        loaded = dp.iter.LoadFilesFromDisk(listed)
        extracted = dp.iter.ReadFilesFromTar(loaded)
        # One full pass over the archives to count their members; the
        # records are grouped in pairs below, so the number of samples is
        # half the member count.
        member_count = sum(1 for _ in extracted)
        decoded = dp.iter.RoutedDecoder(
            extracted,
            handlers=[decoder_imagehandler('pilrgb'), decoder_basichandlers])
        grouped = dp.iter.GroupByKey(decoded, group_size=2)
        # Integer division keeps the length exact (no float round-trip).
        return TransferDatapipe(grouped, phase, member_count // 2)

    train_pipe = _build_phase_datapipe('train')
    val_pipe = _build_phase_datapipe('val')
    return {'train': train_pipe, 'val': val_pipe}
Esempio n. 2
0
    def test_routeddecoder_iterable_datapipe(self):
        """Check RoutedDecoder output for PNG and text files.

        A 2x2 red PNG is written to the temp dir, then decoded twice: once
        with an explicit 'rgb' image handler plus the basic handlers, and once
        with no handlers passed (expected to yield channel-first image data).
        """
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        img = Image.new('RGB', (2, 2), color='red')
        img.save(temp_pngfile_pathname)
        datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt'])
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)

        # The parameter is named `datapipe` so it does not shadow the
        # module-level `dp` alias used above.
        def _helper(datapipe, channel_first=False):
            for rec in datapipe:
                ext = os.path.splitext(rec[0])[1]
                if ext == '.png':
                    # 2x2 pure-red image as floats, height x width x channel.
                    expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single)
                    if channel_first:
                        expected = expected.transpose(2, 0, 1)
                    self.assertEqual(rec[1], expected)
                else:
                    # Read the reference text through a context manager so
                    # the file handle is closed deterministically.
                    with open(rec[0], 'rb') as f:
                        expected_text = f.read().decode('utf-8')
                    self.assertEqual(rec[1], expected_text)

        datapipe3 = dp.iter.RoutedDecoder(datapipe2, decoder_imagehandler('rgb'))
        datapipe3.add_handler(decoder_basichandlers)
        _helper(datapipe3)

        # No handlers given: relies on the decoder's built-in defaults.
        datapipe4 = dp.iter.RoutedDecoder(datapipe2)
        _helper(datapipe4, channel_first=True)
Esempio n. 3
0
 def __init__(self,
              datapipe: Iterable[Tuple[str, BufferedIOBase]],
              *handlers: Callable,
              key_fn: Callable = extension_extract_fn) -> None:
     """Store the source datapipe and build the decoder.

     Args:
         datapipe: Source iterable, annotated as ``(str, BufferedIOBase)``
             tuples.
         *handlers: Decoder handlers. When none are given,
             ``decoder_basichandlers`` and ``decoder_imagehandler('torch')``
             are used.
         key_fn: Passed through to ``Decoder`` as ``key_fn``.
     """
     super().__init__()
     self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
     # Fall back to the stock handlers only when the caller supplied none.
     active_handlers = handlers or (decoder_basichandlers,
                                    decoder_imagehandler('torch'))
     self.decoder = Decoder(*active_handlers, key_fn=key_fn)
Esempio n. 4
0
 def __init__(self,
              datapipe: Iterable[Tuple[str, BufferedIOBase]],
              *,
              handlers: Union[None, List[Callable]] = None,
              length: int = -1):
     """Store the source datapipe, build the decoder and record the length.

     Args:
         datapipe: Source iterable, annotated as ``(str, BufferedIOBase)``
             tuples.
         handlers: Keyword-only list of decoder handlers. ``None`` or an
             empty list selects ``decoder_basichandlers`` and
             ``decoder_imagehandler('torch')``.
         length: Keyword-only value stored on ``self.length``; defaults
             to -1.
     """
     super().__init__()
     self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
     # A falsy `handlers` (None or empty) means "use the stock handlers".
     if not handlers:
         decoder_handlers = [decoder_basichandlers,
                             decoder_imagehandler('torch')]
     else:
         decoder_handlers = handlers
     self.decoder = Decoder(decoder_handlers)
     self.length: int = length
Esempio n. 5
0
 def __init__(self,
              datapipe: Iterable[Tuple[str, BufferedIOBase]],
              *handlers: Callable,
              key_fn: Callable = extension_extract_fn) -> None:
     """Store the source datapipe, build the decoder and warn of deprecation.

     Args:
         datapipe: Source iterable, annotated as ``(str, BufferedIOBase)``
             tuples.
         *handlers: Decoder handlers. When none are given,
             ``decoder_basichandlers`` and ``decoder_imagehandler('torch')``
             are used.
         key_fn: Passed through to ``Decoder`` as ``key_fn``.
     """
     super().__init__()
     self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
     # Use the caller's handlers when present, else the stock pair.
     if handlers:
         selected = handlers
     else:
         selected = (decoder_basichandlers, decoder_imagehandler('torch'))
     self.decoder = Decoder(*selected, key_fn=key_fn)
     # Emit the deprecation notice after construction has succeeded.
     _deprecation_warning(
         type(self).__name__,
         deprecation_version="1.12",
         removal_version="1.13",
         old_functional_name="routed_decode",
     )
Esempio n. 6
0
    def test_routeddecoder_iterable_datapipe(self):
        """Check RoutedDecoder output for PNG and text files.

        A 2x2 red PNG is written to the temp dir and decoded with an 'rgb'
        image handler plus the basic handlers; every record is compared with
        the expected array or with the file's raw UTF-8 text.
        """
        temp_dir = self.temp_dir.name
        temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png")
        img = Image.new('RGB', (2, 2), color='red')
        img.save(temp_pngfile_pathname)
        datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt'])
        datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1)
        datapipe3 = dp.iter.RoutedDecoder(datapipe2, handlers=[decoder_imagehandler('rgb')])
        datapipe3.add_handler(decoder_basichandlers)

        for rec in datapipe3:
            ext = os.path.splitext(rec[0])[1]
            if ext == '.png':
                # 2x2 pure-red image as floats, height x width x channel.
                expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single)
                self.assertTrue(np.array_equal(rec[1], expected))
            else:
                # Read the reference text through a context manager so the
                # file handle is closed deterministically.
                with open(rec[0], 'rb') as f:
                    expected_text = f.read().decode('utf-8')
                self.assertEqual(rec[1], expected_text)