Example 1
    def test_threading_faster(self):
        sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                     self.split)
        transform = transforms.Compose([
            transforms.Lambda(lambda x: cv2.imdecode(x, -1)),
            transforms.ToPILImage(),
            transforms.Resize(84),
            transforms.ToTensor()
        ])
        dataset1 = EpisodicHDF5ClassDataset(self.dataset_spec,
                                            self.split,
                                            sampler=sampler,
                                            epoch_size=self.num_episodes * 10,
                                            image_size=84,
                                            transforms=transform,
                                            shuffle_seed=1234)

        dataloader1 = torch.utils.data.DataLoader(
            dataset1,
            1,
            num_workers=2,
            shuffle=True,
            worker_init_fn=dataset1.setup)

        threaded = 0
        t = time.time()
        for _ in dataloader1:
            threaded += time.time() - t
            t = time.time()

        logging.info("Threaded time %.03fs" % threaded)
        dataset1.setup()
        dataloader2 = torch.utils.data.DataLoader(
            dataset1,
            1,
            num_workers=0,
            shuffle=True,
        )
        nothreaded = 0
        t = time.time()
        for _ in dataloader2:
            nothreaded += time.time() - t
            t = time.time()

        logging.info("Sequential time %.03fs" % nothreaded)
        self.assertGreater(nothreaded, threaded)
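Example 1 compares a multi-worker DataLoader against a single-process one by accumulating the elapsed time per batch. Below is a minimal, self-contained sketch of that per-batch timing pattern, applied to a trivial in-memory tensor dataset (the dataset and batch size are illustrative, not taken from the test above):

import time

import torch

loader = torch.utils.data.DataLoader(torch.arange(1000), batch_size=10)
elapsed = 0.0
t = time.time()
for _ in loader:
    # Accumulate the time spent waiting for each batch, then reset the clock.
    elapsed += time.time() - t
    t = time.time()
print("total iteration time %.3fs" % elapsed)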
Example 2
    def test_non_deterministic_shuffle(self):
        """Different Readers generate different episode compositions.

        Even with the same episode descriptions, the content should be different.
        """
        num_episodes = 10
        init_rng = sampling.RNG
        seed = 20181120
        episode_streams = []
        chunk_sizes = []
        try:
            for _ in range(2):
                sampling.RNG = np.random.RandomState(seed)
                sampler = sampling.EpisodeDescriptionSampler(
                    self.dataset_spec,
                    self.split,
                    episode_descr_config=config.EpisodeDescriptionConfig())
                episodes = self.generate_episodes(sampler, num_episodes)
                episode_streams.append(episodes)
                chunk_size = sampler.compute_chunk_sizes()
                chunk_sizes.append(chunk_size)
                for examples, targets in episodes:
                    self.check_episode_consistency(examples, targets,
                                                   chunk_size)

        finally:
            # Restore the original RNG
            sampling.RNG = init_rng

        self.assertEqual(chunk_sizes[0], chunk_sizes[1])

        # It is unlikely that all episodes will be the same
        num_identical_episodes = 0
        for ((examples1, targets1), (examples2,
                                     targets2)) in zip(*episode_streams):
            self.check_episode_consistency(examples1, targets1, chunk_sizes[0])
            self.check_episode_consistency(examples2, targets2, chunk_sizes[1])
            self.assertAllEqual(targets1, targets2)
            if all(examples1 == examples2):
                num_identical_episodes += 1

        self.assertNotEqual(num_identical_episodes, num_episodes)
Example 3
    def test_non_deterministic_shuffle(self):
        """Different Readers generate different episode compositions.

        Even with the same episode descriptions, the content should be different.
        """
        num_episodes = 10
        init_rng = sampling.RNG
        seed = 20181120
        episode_streams = []

        try:
            for _ in range(2):
                sampling.RNG = np.random.RandomState(seed)
                sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                             self.split)
                episodes = self.generate_episodes(sampler, num_episodes)
                episode_streams.append(episodes)
                for episode in episodes:
                    examples, targets = unpack_episode(episode)
                    self.check_episode_consistency(examples, targets)

        finally:
            # Restore the original RNG
            sampling.RNG = init_rng

        # It is unlikely that all episodes will be the same
        num_identical_episodes = 0
        for episode1, episode2 in zip(*episode_streams):
            examples1, targets1 = unpack_episode(episode1)
            examples2, targets2 = unpack_episode(episode2)
            self.check_episode_consistency(examples1, targets1)
            self.check_episode_consistency(examples2, targets2)
            np.testing.assert_array_equal(targets1, targets2)
            if np.equal(examples1, examples2).all():
                num_identical_episodes += 1

        self.assertNotEqual(num_identical_episodes, num_episodes)
Example 4
    def test_deterministic_tfseed(self):
        """Tests episode generation determinism when shuffle queues are seeded."""
        num_episodes = 10
        seed = 20181120
        episode_streams = []
        chunk_sizes = []
        init_rng = sampling.RNG
        try:
            for _ in range(2):
                sampling.RNG = np.random.RandomState(seed)
                sampler = sampling.EpisodeDescriptionSampler(
                    self.dataset_spec,
                    self.split,
                    episode_descr_config=config.EpisodeDescriptionConfig())
                episodes = self.generate_episodes(sampler,
                                                  num_episodes,
                                                  shuffle_seed=seed)
                episode_streams.append(episodes)
                chunk_size = sampler.compute_chunk_sizes()
                chunk_sizes.append(chunk_size)
                for examples, targets in episodes:
                    self.check_episode_consistency(examples, targets,
                                                   chunk_size)

        finally:
            # Restore the original RNG
            sampling.RNG = init_rng

        self.assertEqual(chunk_sizes[0], chunk_sizes[1])

        for ((examples1, targets1), (examples2,
                                     targets2)) in zip(*episode_streams):
            self.check_episode_consistency(examples1, targets1, chunk_sizes[0])
            self.check_episode_consistency(examples2, targets2, chunk_sizes[1])
            self.assertAllEqual(examples1, examples2)
            self.assertAllEqual(targets1, targets2)
Example 5
 def make_sampler(self):
     return sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split,
         config.EpisodeDescriptionConfig(num_ways=self.num_ways,
                                         num_support=self.num_support,
                                         num_query=self.num_query))
Example 6
 def make_sampler(self):
     """Helper function to make a new instance of the tested sampler."""
     return sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split, config.EpisodeDescriptionConfig())
Example 7
 def make_sampler(self):
     return sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split,
         config.EpisodeDescriptionConfig(num_ways=self.num_ways))
Example 8
 def test_valid(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  Split.VALID)
     self.generate_and_check(sampler, 10)
Example 9
 def test_too_many_ways(self):
   """Too many ways to have 1 example per class with default variable shots."""
   sampler = sampling.EpisodeDescriptionSampler(
       self.dataset_spec, Split.TRAIN, num_ways=600)
   with self.assertRaises(ValueError):
     sampler.sample_episode_description()
Example 10
 def test_no_query(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_query=0))
     self.check_expected_structure(sampler)
Example 11
 def test_test(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         Split.TEST,
         episode_descr_config=config.EpisodeDescriptionConfig())
     self.generate_and_check(sampler, 10)
Example 12
def make_one_source_episode_pipeline(dataset_spec,
                                     use_dag_ontology,
                                     use_bilevel_ontology,
                                     split,
                                     pool=None,
                                     num_ways=None,
                                     num_support=None,
                                     num_query=None,
                                     shuffle_buffer_size=None,
                                     read_buffer_size_bytes=None,
                                     image_size=None):
    """Returns a pipeline emitting data from one single source as Episodes.

  Args:
    dataset_spec: A DatasetSpecification object defining what to read from.
    use_dag_ontology: Whether to use source's ontology in the form of a DAG to
      sample episode classes.
    use_bilevel_ontology: Whether to use source's bilevel ontology (consisting
      of superclasses and subclasses) to sample episode classes.
    split: A learning_spec.Split object identifying the source (meta-)split.
    pool: String (optional), for example-split datasets, which example split to
      use ('train', 'valid', or 'test'), used at meta-test time only.
    num_ways: Integer (optional), fixes the number of classes ("ways") to be
      used in each episode if provided.
    num_support: Integer (optional), fixes the number of examples for each class
      in the support set if provided.
    num_query: Integer (optional), fixes the number of examples for each class
      in the query set if provided.
    shuffle_buffer_size: int or None, shuffle buffer size for each Dataset.
    read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
    image_size: int, desired image size used during decoding.

  Returns:
    A Dataset instance that outputs fully-assembled and decoded episodes.
  """
    use_all_classes = False
    if pool is not None:
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    episode_reader = reader.EpisodeReader(dataset_spec, split,
                                          shuffle_buffer_size,
                                          read_buffer_size_bytes)
    sampler = sampling.EpisodeDescriptionSampler(
        episode_reader.dataset_spec,
        split,
        pool=pool,
        use_dag_hierarchy=use_dag_ontology,
        use_bilevel_hierarchy=use_bilevel_ontology,
        use_all_classes=use_all_classes,
        num_ways=num_ways,
        num_support=num_support,
        num_query=num_query)
    dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)

    # Episodes coming out of `dataset` contain flushed examples and are internally
    # padded with dummy examples. `process_episode` discards flushed examples,
    # splits the episode into support and query sets, removes the dummy examples
    # and decodes the example strings.
    chunk_sizes = sampler.compute_chunk_sizes()
    map_fn = functools.partial(process_episode,
                               chunk_sizes=chunk_sizes,
                               image_size=image_size)
    dataset = dataset.map(map_fn)

    # Overlap episode processing and training.
    dataset = dataset.prefetch(1)
    return dataset
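The comment above refers to the chunk layout produced by sampler.compute_chunk_sizes(): each flat episode is laid out as a flush chunk, a support chunk and a query chunk, and process_episode cuts it apart before discarding the flush part (Example 29 exercises the same layout through a split_into_chunks helper). Here is a minimal NumPy sketch of that cutting step, assuming chunk_sizes is a (flush_size, support_size, query_size) triple; the helper and the toy targets are illustrative only:

import numpy as np


def split_into_chunks(flat, chunk_sizes):
    # chunk_sizes is assumed to be (flush_size, support_size, query_size).
    flush_size, support_size, query_size = chunk_sizes
    support_end = flush_size + support_size
    return (flat[:flush_size],
            flat[flush_size:support_end],
            flat[support_end:support_end + query_size])


# A flat target vector with 2 flushed, 3 support and 3 query entries.
targets = np.array([0, 1, 0, 0, 1, 0, 1, 1])
flush, support, query = split_into_chunks(targets, (2, 3, 3))
print(flush, support, query)  # [0 1] [0 0 1] [0 1 1]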
Example 13
def make_multisource_episode_pipeline(dataset_spec_list,
                                      use_dag_ontology_list,
                                      use_bilevel_ontology_list,
                                      split,
                                      pool=None,
                                      num_ways=None,
                                      num_support=None,
                                      num_query=None,
                                      shuffle_buffer_size=None,
                                      read_buffer_size_bytes=None,
                                      image_size=None):
    """Returns a pipeline emitting data from multiple sources as Episodes.

  Each episode only contains data from one single source. For each episode, its
  source is sampled uniformly across all sources.

  Args:
    dataset_spec_list: A list of DatasetSpecification, one for each source.
    use_dag_ontology_list: A list of Booleans, one for each source: whether to
      use that source's DAG-structured ontology to sample episode classes.
    use_bilevel_ontology_list: A list of Booleans, one for each source: whether
      to use that source's bi-level ontology to sample episode classes.
    split: A learning_spec.Split object identifying the sources' split. It is the
      same for all datasets.
    pool: String (optional), for example-split datasets, which example split to
      use ('train', 'valid', or 'test'), used at meta-test time only.
    num_ways: Integer (optional), fixes the number of classes ("ways") to be
      used in each episode if provided.
    num_support: Integer (optional), fixes the number of examples for each class
      in the support set if provided.
    num_query: Integer (optional), fixes the number of examples for each class
      in the query set if provided.
    shuffle_buffer_size: int or None, shuffle buffer size for each Dataset.
    read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
    image_size: int, desired image size used during decoding.

  Returns:
    A Dataset instance that outputs fully-assembled and decoded episodes.
  """
    if pool is not None:
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    sources = []
    for (dataset_spec, use_dag_ontology,
         use_bilevel_ontology) in zip(dataset_spec_list, use_dag_ontology_list,
                                      use_bilevel_ontology_list):
        episode_reader = reader.EpisodeReader(dataset_spec, split,
                                              shuffle_buffer_size,
                                              read_buffer_size_bytes)
        sampler = sampling.EpisodeDescriptionSampler(
            episode_reader.dataset_spec,
            split,
            pool=pool,
            use_dag_hierarchy=use_dag_ontology,
            use_bilevel_hierarchy=use_bilevel_ontology,
            num_ways=num_ways,
            num_support=num_support,
            num_query=num_query)
        dataset = episode_reader.create_dataset_input_pipeline(sampler,
                                                               pool=pool)
        sources.append(dataset)

    # Sample uniformly among sources
    dataset = tf.data.experimental.sample_from_datasets(sources)

    # Episodes coming out of `dataset` contain flushed examples and are internally
    # padded with dummy examples. `process_episode` discards flushed examples,
    # splits the episode into support and query sets, removes the dummy examples
    # and decodes the example strings.
    chunk_sizes = sampler.compute_chunk_sizes()
    map_fn = functools.partial(process_episode,
                               chunk_sizes=chunk_sizes,
                               image_size=image_size)
    dataset = dataset.map(map_fn)

    # Overlap episode processing and training.
    dataset = dataset.prefetch(1)
    return dataset
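The multi-source pipeline above interleaves one episode dataset per source with tf.data.experimental.sample_from_datasets; when no weights are given, each input dataset is drawn from with equal probability. A small self-contained illustration, where two constant datasets stand in for per-source episode datasets:

import tensorflow as tf

source_a = tf.data.Dataset.from_tensors('a').repeat()
source_b = tf.data.Dataset.from_tensors('b').repeat()

# Without explicit weights, the sources are sampled uniformly at random.
mixed = tf.data.experimental.sample_from_datasets([source_a, source_b])
for item in mixed.take(6):
    print(item.numpy())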
Example 14
 def test_fixed_ways(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  self.split,
                                                  num_ways=12)
     self.generate_and_check(sampler, 10)
Example 15
 def make_sampler(self):
     return sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split,
         config.EpisodeDescriptionConfig(
             min_examples_in_class=self.min_examples_in_class))
Example 16
 def test_fixed_shots(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  self.split,
                                                  num_support=3,
                                                  num_query=7)
     self.generate_and_check(sampler, 10)
Example 17
 def test_no_query(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  self.split,
                                                  num_query=0)
     self.generate_and_check(sampler, 10)
Example 18
 def test_test(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  Split.TEST)
     self.generate_and_check(sampler, 10)
Example 19
def make_one_source_episode_pipeline(dataset_spec,
                                     use_dag_ontology,
                                     use_bilevel_ontology,
                                     split,
                                     episode_descr_config,
                                     pool=None,
                                     shuffle_buffer_size=None,
                                     read_buffer_size_bytes=None,
                                     num_prefetch=0,
                                     image_size=None,
                                     num_to_take=None):
  """Returns a pipeline emitting data from one single source as Episodes.

  Args:
    dataset_spec: A DatasetSpecification object defining what to read from.
    use_dag_ontology: Whether to use source's ontology in the form of a DAG to
      sample episode classes.
    use_bilevel_ontology: Whether to use source's bilevel ontology (consisting
      of superclasses and subclasses) to sample episode classes.
    split: A learning_spec.Split object identifying the source (meta-)split.
    episode_descr_config: An instance of EpisodeDescriptionConfig containing
      parameters relating to sampling shots and ways for episodes.
    pool: String (optional), for example-split datasets, which example split to
      use ('train', 'valid', or 'test'), used at meta-test time only.
    shuffle_buffer_size: int or None, shuffle buffer size for each Dataset.
    read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
    num_prefetch: int, the number of examples to prefetch for each class of each
      dataset. Prefetching occurs just after the class-specific Dataset object
      is constructed. If < 1, no prefetching occurs.
    image_size: int, desired image size used during decoding.
    num_to_take: Optional, an int specifying a number of elements to pick from
      each class' tfrecord. If specified, the available images of each class
      will be restricted to that int. By default no restriction is applied and
      all data is used.

  Returns:
    A Dataset instance that outputs fully-assembled and decoded episodes.
  """
  use_all_classes = False
  if pool is not None:
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  if num_to_take is None:
    num_to_take = -1
  episode_reader = reader.EpisodeReader(dataset_spec, split,
                                        shuffle_buffer_size,
                                        read_buffer_size_bytes, num_prefetch,
                                        num_to_take)
  sampler = sampling.EpisodeDescriptionSampler(
      episode_reader.dataset_spec,
      split,
      episode_descr_config,
      pool=pool,
      use_dag_hierarchy=use_dag_ontology,
      use_bilevel_hierarchy=use_bilevel_ontology,
      use_all_classes=use_all_classes)
  dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)

  # Episodes coming out of `dataset` contain flushed examples and are internally
  # padded with dummy examples. `process_episode` discards flushed examples,
  # splits the episode into support and query sets, removes the dummy examples
  # and decodes the example strings.
  chunk_sizes = sampler.compute_chunk_sizes()
  map_fn = functools.partial(
      process_episode, chunk_sizes=chunk_sizes, image_size=image_size)
  dataset = dataset.map(map_fn)

  # Overlap episode processing and training.
  dataset = dataset.prefetch(1)
  return dataset
Example 20
def make_one_source_episode_pipeline(dataset_spec,
                                     use_dag_ontology,
                                     use_bilevel_ontology,
                                     split,
                                     episode_descr_config,
                                     pool=None,
                                     shuffle_buffer_size=None,
                                     read_buffer_size_bytes=None,
                                     num_prefetch=0,
                                     image_size=None,
                                     num_to_take=None,
                                     ignore_hierarchy_probability=0.0,
                                     simclr_episode_fraction=0.0):
    """Returns a pipeline emitting data from one single source as Episodes.

  Args:
    dataset_spec: A DatasetSpecification object defining what to read from.
    use_dag_ontology: Whether to use source's ontology in the form of a DAG to
      sample episode classes.
    use_bilevel_ontology: Whether to use source's bilevel ontology (consisting
      of superclasses and subclasses) to sample episode classes.
    split: A learning_spec.Split object identifying the source (meta-)split.
    episode_descr_config: An instance of EpisodeDescriptionConfig containing
      parameters relating to sampling shots and ways for episodes.
    pool: String (optional), for example-split datasets, which example split to
      use ('train', 'valid', or 'test'), used at meta-test time only.
    shuffle_buffer_size: int or None, shuffle buffer size for each Dataset.
    read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
    num_prefetch: int, the number of examples to prefetch for each class of each
      dataset. Prefetching occurs just after the class-specific Dataset object
      is constructed. If < 1, no prefetching occurs.
    image_size: int, desired image size used during decoding.
    num_to_take: Optional, an int specifying a number of elements to pick from
      each class' tfrecord. If specified, the available images of each class
      will be restricted to that int. By default no restriction is applied and
      all data is used.
    ignore_hierarchy_probability: Float, if using a hierarchy, this flag makes
      the sampler ignore the hierarchy for this proportion of episodes and
      instead sample categories uniformly.
    simclr_episode_fraction: Float, fraction of episodes that will be converted
      to SimCLR Episodes as described in the CrossTransformers paper.


  Returns:
    A Dataset instance that outputs tuples of fully-assembled and decoded
      episodes zipped with the ID of their data source of origin.
  """
    use_all_classes = False
    if pool is not None:
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    if num_to_take is None:
        num_to_take = -1

    num_unique_episodes = episode_descr_config.num_unique_episodes
    episode_reader = reader.EpisodeReader(dataset_spec, split,
                                          shuffle_buffer_size,
                                          read_buffer_size_bytes, num_prefetch,
                                          num_to_take, num_unique_episodes)
    sampler = sampling.EpisodeDescriptionSampler(
        episode_reader.dataset_spec,
        split,
        episode_descr_config,
        pool=pool,
        use_dag_hierarchy=use_dag_ontology,
        use_bilevel_hierarchy=use_bilevel_ontology,
        use_all_classes=use_all_classes,
        ignore_hierarchy_probability=ignore_hierarchy_probability)
    dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
    # Episodes coming out of `dataset` contain flushed examples and are internally
    # padded with dummy examples. `process_episode` discards flushed examples,
    # splits the episode into support and query sets, removes the dummy examples
    # and decodes the example strings.
    chunk_sizes = sampler.compute_chunk_sizes()
    map_fn = functools.partial(process_episode,
                               chunk_sizes=chunk_sizes,
                               image_size=image_size,
                               simclr_episode_fraction=simclr_episode_fraction)
    dataset = dataset.map(map_fn)
    # There is only one data source, so we know that all episodes belong to it,
    # but for interface consistency, zip with a dataset identifying the source.
    source_id_dataset = tf.data.Dataset.from_tensors(0).repeat()
    dataset = tf.data.Dataset.zip((dataset, source_id_dataset))

    # Overlap episode processing and training.
    dataset = dataset.prefetch(1)
    return dataset
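Example 20 attaches a data-source ID to every episode by zipping the episode dataset with an endlessly repeated constant. A minimal sketch of that zip pattern, with a three-element range standing in for the episode dataset:

import tensorflow as tf

episodes = tf.data.Dataset.range(3)                     # stand-in for episodes
source_ids = tf.data.Dataset.from_tensors(0).repeat()   # single source, ID 0
for episode, source_id in tf.data.Dataset.zip((episodes, source_ids)):
    print(int(episode), int(source_id))  # 0 0, then 1 0, then 2 0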
Example 21
def make_multisource_episode_pipeline(dataset_spec_list,
                                      use_dag_ontology_list,
                                      use_bilevel_ontology_list,
                                      split,
                                      episode_descr_config,
                                      pool=None,
                                      shuffle_buffer_size=None,
                                      read_buffer_size_bytes=None,
                                      num_prefetch=0,
                                      image_size=None,
                                      num_to_take=None):
  """Returns a pipeline emitting data from multiple sources as Episodes.

  Each episode only contains data from one single source. For each episode, its
  source is sampled uniformly across all sources.

  Args:
    dataset_spec_list: A list of DatasetSpecification, one for each source.
    use_dag_ontology_list: A list of Booleans, one for each source: whether to
      use that source's DAG-structured ontology to sample episode classes.
    use_bilevel_ontology_list: A list of Booleans, one for each source: whether
      to use that source's bi-level ontology to sample episode classes.
    split: A learning_spec.Split object identifying the sources' split. It is the
      same for all datasets.
    episode_descr_config: An instance of EpisodeDescriptionConfig containing
      parameters relating to sampling shots and ways for episodes.
    pool: String (optional), for example-split datasets, which example split to
      use ('train', 'valid', or 'test'), used at meta-test time only.
    shuffle_buffer_size: int or None, shuffle buffer size for each Dataset.
    read_buffer_size_bytes: int or None, buffer size for each TFRecordDataset.
    num_prefetch: int, the number of examples to prefetch for each class of each
      dataset. Prefetching occurs just after the class-specific Dataset object
      is constructed. If < 1, no prefetching occurs.
    image_size: int, desired image size used during decoding.
    num_to_take: Optional, a list specifying for each dataset the number of
      examples per class to restrict to (for this given split). If provided, its
      length must be the same as len(dataset_spec_list). If None, no restrictions
      applied to any dataset and all data per class is used.

  Returns:
    A Dataset instance that outputs fully-assembled and decoded episodes.
  """
  if pool is not None:
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  if num_to_take is not None and len(num_to_take) != len(dataset_spec_list):
    raise ValueError('num_to_take does not have the same length as '
                     'dataset_spec_list.')
  if num_to_take is None:
    num_to_take = [-1] * len(dataset_spec_list)
  sources = []
  for (dataset_spec, use_dag_ontology, use_bilevel_ontology,
       num_to_take_for_dataset) in zip(dataset_spec_list, use_dag_ontology_list,
                                       use_bilevel_ontology_list, num_to_take):
    episode_reader = reader.EpisodeReader(dataset_spec, split,
                                          shuffle_buffer_size,
                                          read_buffer_size_bytes, num_prefetch,
                                          num_to_take_for_dataset)
    sampler = sampling.EpisodeDescriptionSampler(
        episode_reader.dataset_spec,
        split,
        episode_descr_config,
        pool=pool,
        use_dag_hierarchy=use_dag_ontology,
        use_bilevel_hierarchy=use_bilevel_ontology)
    dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
    sources.append(dataset)

  # Sample uniformly among sources
  dataset = tf.data.experimental.sample_from_datasets(sources)

  # Episodes coming out of `dataset` contain flushed examples and are internally
  # padded with dummy examples. `process_episode` discards flushed examples,
  # splits the episode into support and query sets, removes the dummy examples
  # and decodes the example strings.
  chunk_sizes = sampler.compute_chunk_sizes()
  map_fn = functools.partial(
      process_episode, chunk_sizes=chunk_sizes, image_size=image_size)
  dataset = dataset.map(map_fn)

  # Overlap episode processing and training.
  dataset = dataset.prefetch(1)
  return dataset
Example 22
 def test_fixed_ways(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_ways=12))
     self.generate_and_check(sampler, 10)
Example 23
 def test_fixed_ways(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_ways=12))
     self.check_expected_structure(sampler)
Example 24
 def test_fixed_ways(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  self.split,
                                                  num_ways=12)
     self.check_expected_structure(sampler)
Example 25
 def test_no_query(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_query=0))
     self.generate_and_check(sampler, 10)
Example 26
 def test_train(self):
     """Tests that a few episodes are consistent."""
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  Split.TRAIN)
     self.generate_and_check(sampler, 10)
Example 27
 def make_sampler(self):
   return sampling.EpisodeDescriptionSampler(
       self.dataset_spec, self.split, num_query=self.num_query)
Example 28
 def test_fixed_shots(self):
     sampler = sampling.EpisodeDescriptionSampler(self.dataset_spec,
                                                  self.split,
                                                  num_support=3,
                                                  num_query=7)
     self.check_expected_structure(sampler)
Example 29
    def test_flush_logic(self):
        """Tests the "flush" logic avoiding example duplication in an episode."""
        # Generate two episodes from un-shuffled data sources. For classes where
        # there are enough examples for both, new examples should be used for the
        # second episodes. Otherwise, the first examples should be re-used.
        # A data_spec with classes between 10 and 29 examples.
        num_classes = 30
        dataset_spec = DatasetSpecification(
            name=None,
            classes_per_split={
                Split.TRAIN: num_classes,
                Split.VALID: 0,
                Split.TEST: 0
            },
            images_per_class={i: 10 + i
                              for i in range(num_classes)},
            class_names=None,
            path=None,
            file_pattern='{}.tfrecords')
        # Sample from all train classes, 5 + 5 examples from each episode
        sampler = sampling.EpisodeDescriptionSampler(
            dataset_spec,
            Split.TRAIN,
            episode_descr_config=config.EpisodeDescriptionConfig(
                num_ways=num_classes, num_support=5, num_query=5))
        episodes = self.generate_episodes(sampler,
                                          num_episodes=2,
                                          shuffle=False)

        # The "flush" part of the second episode should contain 0 from class_id 0, 1
        # for 1, ..., 9 for 9, and then 0 for 10 and the following.
        chunk_sizes = sampler.compute_chunk_sizes()
        _, episode2 = episodes
        examples2, targets2 = episode2
        flush_target2, _, _ = split_into_chunks(targets2, chunk_sizes)
        for class_id in range(10):
            self.assertEqual(
                sum(target == class_id for target in flush_target2), class_id)
        for class_id in range(10, num_classes):
            self.assertEqual(
                sum(target == class_id for target in flush_target2), 0)

        # The "support" part of the second episode should start at example 0 for
        # class_ids from 0 to 9 (included), and at example 10 for class_id 10 and
        # higher.
        _, support_examples2, query_examples2 = split_into_chunks(
            examples2, chunk_sizes)

        def _build_class_id_to_example_ids(examples):
            # Build a mapping: class_id -> list of example ids
            mapping = collections.defaultdict(list)
            for example in examples:
                if not example:
                    # Padding is at the end
                    break
                class_id, example_id = example.decode().split('.')
                mapping[int(class_id)].append(int(example_id))
            return mapping

        support2_example_ids = _build_class_id_to_example_ids(
            support_examples2)
        query2_example_ids = _build_class_id_to_example_ids(query_examples2)

        for class_id in range(10):
            self.assertCountEqual(support2_example_ids[class_id],
                                  list(range(5)))
            self.assertCountEqual(query2_example_ids[class_id],
                                  list(range(5, 10)))

        for class_id in range(10, num_classes):
            self.assertCountEqual(support2_example_ids[class_id],
                                  list(range(10, 15)))
            self.assertCountEqual(query2_example_ids[class_id],
                                  list(range(15, 20)))
Example 30
 def test_shots_too_big(self):
   """Asserts failure if not enough examples to fulfill support and query."""
   sampler = sampling.EpisodeDescriptionSampler(
       self.dataset_spec, self.split, num_support=5, num_query=15)
   with self.assertRaises(ValueError):
     sampler.sample_episode_description()