예제 #1
0
 def _split_generators(self, dl_manager):
     """Declare the TRAIN and TEST splits.

     Args:
       dl_manager: Unused; it is discarded immediately.

     Returns:
       A two-element list of `SplitGenerator`s: TRAIN (10 shards,
       `num_examples=1000`) followed by TEST (no explicit shard count,
       `num_examples=725`).
     """
     del dl_manager
     train_split = splits_lib.SplitGenerator(
         name=splits_lib.Split.TRAIN,
         num_shards=10,
         gen_kwargs={"num_examples": 1000},
     )
     # No explicit shard count for TEST (liquid sharding).
     test_split = splits_lib.SplitGenerator(
         name=splits_lib.Split.TEST,
         num_shards=None,
         gen_kwargs={"num_examples": 725},
     )
     return [train_split, test_split]
예제 #2
0
 def _split_generators(self, dl_manager):
   """Declare the TRAIN, VALIDATION and TEST splits.

   Args:
     dl_manager: Unused by this implementation.

   Returns:
     A list of three `SplitGenerator`s, in TRAIN / VALIDATION / TEST order,
     each receiving its own `range_` keyword argument.
   """
   # (split name, shard count, range passed to the generator as `range_`)
   split_table = (
       (splits.Split.TRAIN, 2, range(20)),
       (splits.Split.VALIDATION, 1, range(20, 30)),
       (splits.Split.TEST, 1, range(30, 40)),
   )
   return [
       splits.SplitGenerator(
           name=split_name,
           num_shards=shard_count,
           gen_kwargs={"range_": example_range},
       )
       for split_name, shard_count, example_range in split_table
   ]
예제 #3
0
 def _split_generators(self, dl_manager):
     """Declare a 2-shard TRAIN split and a 1-shard TEST split.

     Args:
       dl_manager: Unused; it is discarded immediately.

     Returns:
       A list with the TRAIN `SplitGenerator` (`range_=range(20)`) followed
       by the TEST `SplitGenerator` (`range_=range(20, 30)`).
     """
     del dl_manager
     # The 30 indices are divided as 0..19 for train and 20..29 for test.
     train_split = splits_lib.SplitGenerator(
         name=splits_lib.Split.TRAIN,
         num_shards=2,
         gen_kwargs={"range_": range(20)},
     )
     test_split = splits_lib.SplitGenerator(
         name=splits_lib.Split.TEST,
         num_shards=1,
         gen_kwargs={"range_": range(20, 30)},
     )
     return [train_split, test_split]
예제 #4
0
 def _split_generators(self, _):
   """Return a single 5-shard split keyed by `Split.ALL`.

   This definition is intentionally invalid: per the inline note, `ALL`
   cannot be used as a split key.
   """
   invalid_split = splits_lib.SplitGenerator(
       name=splits_lib.Split.ALL,  # Error: ALL cannot be used as Split key
       num_shards=5,
   )
   return [invalid_split]
예제 #5
0
 def _split_generators(self, dl_manager):
     """Declare a single TRAIN split with no generator keyword arguments.

     Args:
       dl_manager: Unused; it is discarded immediately.

     Returns:
       A one-element list holding the TRAIN `SplitGenerator`.
     """
     del dl_manager
     train_split = splits_lib.SplitGenerator(
         name=splits_lib.Split.TRAIN,
         gen_kwargs={},
     )
     return [train_split]
    def _split_generators(self, dl_manager, pipeline):
        """Declare TRAIN and TEST splits fed by one Beam collection.

        Args:
          dl_manager: Unused; it is discarded immediately.
          pipeline: Beam pipeline the example collection is attached to.

        Returns:
          A list with the TRAIN and TEST `SplitGenerator`s. Both receive the
          same `examples` collection; they differ only in `num_examples`
          (1000 and 725 respectively).
        """
        del dl_manager

        # Create the 1000 source indices, then map each through _gen_example.
        indices = pipeline | beam.Create(range(1000))
        examples = indices | beam.Map(_gen_example)

        train_split = splits_lib.SplitGenerator(
            name=splits_lib.Split.TRAIN,
            gen_kwargs={"examples": examples, "num_examples": 1000},
        )
        test_split = splits_lib.SplitGenerator(
            name=splits_lib.Split.TEST,
            gen_kwargs={"examples": examples, "num_examples": 725},
        )
        return [train_split, test_split]
예제 #7
0
 def _split_generators(self, dl_manager):
   """Declare TRAIN and TEST through one combined `SplitGenerator`.

   Args:
     dl_manager: Unused; it is discarded immediately.

   Returns:
     A one-element list whose generator names both splits and gives their
     shard counts positionally (2 for TRAIN, 1 for TEST).
   """
   del dl_manager
   split_names = [splits.Split.TRAIN, splits.Split.TEST]
   shard_counts = [2, 1]  # Parallel to split_names.
   combined = splits.SplitGenerator(
       name=split_names,
       num_shards=shard_counts,
   )
   return [combined]
예제 #8
0
  def _split_generators(self, dl_manager):
    """Download the CIFAR archive and declare the TRAIN and TEST splits.

    Args:
      dl_manager: Download manager used to fetch and extract
        `self._cifar_info.url`.

    Returns:
      A list with the TRAIN (10 shards) and TEST (1 shard) `SplitGenerator`s.
      Each receives a lazy `filepaths` generator over its data files.
    """
    extracted_dir = dl_manager.download_and_extract(self._cifar_info.url)
    info = self._cifar_info

    def _iter_paths(names):
      # Lazily join each file name onto <extracted dir>/<prefix>.
      for name in names:
        yield os.path.join(extracted_dir, self._cifar_info.prefix, name)

    train_split = splits.SplitGenerator(
        name=splits.Split.TRAIN,
        num_shards=10,
        gen_kwargs={"filepaths": _iter_paths(info.train_files)},
    )
    test_split = splits.SplitGenerator(
        name=splits.Split.TEST,
        num_shards=1,
        gen_kwargs={"filepaths": _iter_paths(info.test_files)},
    )
    return [train_split, test_split]
 def _split_generators(self, dl_manager):
   """Declare a single TRAIN split with no generator keyword arguments.

   Args:
     dl_manager: Unused; it is discarded immediately.

   Returns:
     A one-element list holding the TRAIN `SplitGenerator`. No shard count
     is passed.
   """
   del dl_manager
   train_split = splits_lib.SplitGenerator(
       name=splits_lib.Split.TRAIN,
       gen_kwargs={},
   )
   return [train_split]
예제 #10
0
 def _split_generators(self, _):
     """Return a single split named with the string "all".

     This definition is intentionally invalid: per the inline note, ALL
     cannot be used as a split key.
     """
     invalid_split = splits_lib.SplitGenerator(
         name="all",  # Error: ALL cannot be used as Split key
     )
     return [invalid_split]
예제 #11
0
 def _split_generators(self, dl_manager):
     """Declare TRAIN and TEST splits with empty generator keyword arguments.

     Args:
       dl_manager: Unused by this implementation.

     Returns:
       A list with the TRAIN `SplitGenerator` followed by the TEST one.
     """
     train_split = splits.SplitGenerator(
         name=splits.Split.TRAIN,
         gen_kwargs={},
     )
     test_split = splits.SplitGenerator(
         name=splits.Split.TEST,
         gen_kwargs={},
     )
     return [train_split, test_split]