Example #1
File: dataset.py Project: licj1/metadl
    def set_fixed_episode_config(self):
        """ Set the episode description configuration. """
        if self.episode_config is not None:
            self.fixed_ways_shots = config.EpisodeDescriptionConfig(
                num_ways=self.episode_config[1],
                num_support=self.episode_config[2],
                num_query=self.episode_config[3],
                min_ways=self.episode_config[1],
                max_ways_upper_bound=self.episode_config[1],
                max_num_query=20,
                max_support_set_size=20,
                max_support_size_contrib_per_class=200,
                min_log_weight=-0.69314718055994529,  # np.log(0.5)
                max_log_weight=0.69314718055994529,  # np.log(2)
                ignore_dag_ontology=True,
                ignore_bilevel_ontology=True,
                min_examples_in_class=0)

        if self.valid_episode_config is not None:
            self.fixed_ways_shots_valid = config.EpisodeDescriptionConfig(
                num_ways=self.valid_episode_config[1],
                num_support=self.valid_episode_config[2],
                num_query=self.valid_episode_config[3],
                min_ways=self.valid_episode_config[1],
                max_ways_upper_bound=self.valid_episode_config[1],
                max_num_query=20,
                max_support_set_size=20,
                max_support_size_contrib_per_class=200,
                min_log_weight=-0.69314718055994529,  # np.log(0.5)
                max_log_weight=0.69314718055994529,  # np.log(2)
                ignore_dag_ontology=True,
                ignore_bilevel_ontology=True,
                min_examples_in_class=0)
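
The two branches above are identical except for the episode tuple they read; a minimal deduplication sketch (the helper name _make_fixed_config is hypothetical, and the values simply mirror the literals used above):

    def _make_fixed_config(ways, support, query):
        # Hypothetical helper: builds the same fixed ways/shots episode config
        # as set_fixed_episode_config above from a (ways, support, query) triple.
        return config.EpisodeDescriptionConfig(
            num_ways=ways,
            num_support=support,
            num_query=query,
            min_ways=ways,
            max_ways_upper_bound=ways,
            max_num_query=20,
            max_support_set_size=20,
            max_support_size_contrib_per_class=200,
            min_log_weight=-0.69314718055994529,  # np.log(0.5)
            max_log_weight=0.69314718055994529,  # np.log(2)
            ignore_dag_ontology=True,
            ignore_bilevel_ontology=True,
            min_examples_in_class=0)

With it, the two branches reduce to self.fixed_ways_shots = _make_fixed_config(*self.episode_config[1:4]), and likewise for the validation tuple.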
Example #2
    def __init__(self,
                 mode,
                 train_set=None,
                 validation_set=None,
                 test_set=None):
        super(MetaDatasetEpisodeReader,
              self).__init__(mode, train_set, validation_set, test_set)

        if mode == 'train':
            train_episode_description = config.EpisodeDescriptionConfig(
                None, None, None)
            self.train_dataset_next_task = self._init_multi_source_dataset(
                train_set, SPLIT_NAME_TO_SPLIT['train'],
                train_episode_description)

        if mode == 'val':
            test_episode_description = config.EpisodeDescriptionConfig(
                None, None, None)
            for item in validation_set:
                next_task = self._init_single_source_dataset(
                    item, SPLIT_NAME_TO_SPLIT['val'],
                    test_episode_description)
                self.validation_set_dict[item] = next_task

        if mode == 'test':
            test_episode_description = config.EpisodeDescriptionConfig(
                None, None, None)
            for item in test_set:
                next_task = self._init_single_source_dataset(
                    item, SPLIT_NAME_TO_SPLIT['test'],
                    test_episode_description)
                self.test_set_dict[item] = next_task
Example #3
  def __init__(self, datasets, split, fixed_ways=None, fixed_support=None,
               fixed_query=None, use_ontology=True):
    assert split in ['train', 'valid', 'test']
    assert isinstance(datasets, (list, tuple))
    assert isinstance(datasets[0], str)

    print(f'Loading MetaDataset for {split}...')
    split = getattr(learning_spec.Split, split.upper())
    # Reading datasets
    # self.datasets = ['aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012',
    # 'omniglot', 'quickdraw', 'vgg_flower']
    self.datasets = datasets

    # Ontology setting
    use_bilevel_ontology_list = []
    use_dag_ontology_list = []

    for dataset in self.datasets:
      bilevel = dag = False
      if dataset == 'omniglot' and use_ontology:
        bilevel = True
      elif dataset == 'ilsvrc_2012' and use_ontology:
        dag = True
      use_bilevel_ontology_list.append(bilevel)
      use_dag_ontology_list.append(dag)

    assert len(self.datasets) == len(use_bilevel_ontology_list)
    assert len(self.datasets) == len(use_dag_ontology_list)

    all_dataset_specs = []
    for dataset_name in self.datasets:
      dataset_records_path = os.path.join(BASE_PATH, dataset_name)
      dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
      all_dataset_specs.append(dataset_spec)

    if fixed_ways and use_ontology:
      max_ways_upper_bound = min_ways = fixed_ways
      self.episode_config = config.EpisodeDescriptionConfig(
          num_query=fixed_query, num_support=fixed_support, min_ways=min_ways,
          max_ways_upper_bound=max_ways_upper_bound, num_ways=None)
    else:
      # Episode description config (if num is None, use gin configuration)
      self.episode_config = config.EpisodeDescriptionConfig(
          num_query=fixed_query, num_support=fixed_support, num_ways=fixed_ways)

    # Episode pipeline
    self.episode_pipeline, self.n_total_classes = \
        pipeline.make_multisource_episode_pipeline2(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=self.episode_config,
            split=split, image_size=84)

    print('MetaDataset loaded: ', ', '.join(self.datasets))
Example #4
 def test_skip_too_many(self):
     # The "valid" split does not have MIN_WAYS (5) classes left if we skip some.
     with self.assertRaises(ValueError):
         sampling.EpisodeDescriptionSampler(
             self.dataset_spec, Split.VALID,
             config.EpisodeDescriptionConfig(
                 min_examples_in_class=self.min_examples_in_class))
Example #5
 def test_too_many_ways(self):
     """Too many ways to have 1 example per class with default variable shots."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, Split.TRAIN,
         config.EpisodeDescriptionConfig(num_ways=600))
     with self.assertRaises(ValueError):
         sampler.sample_episode_description()
Example #6
 def test_fixed_shots(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_support=3,
                                                              num_query=7))
     self.check_expected_structure(sampler)
Example #7
 def test_train(self):
     """Tests that a few episodes are consistent."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         Split.TRAIN,
         episode_descr_config=config.EpisodeDescriptionConfig())
     self.generate_and_check(sampler, 10)
Example #8
 def make_sampler(self):
   return sampling.EpisodeDescriptionSampler(
       self.dataset_spec, self.split,
       config.EpisodeDescriptionConfig(
           num_ways=self.num_ways,
           num_support=self.num_support,
           num_query=self.num_query))
Example #9
  def test_episode_switch_frequency(self):
    """Tests episode switch frequency."""
    num_episodes = 9
    switch_freq = 3
    num_ways = 5
    episode_descr_config = config.EpisodeDescriptionConfig()
    episode_descr_config.episode_description_switch_frequency = switch_freq
    episode_descr_config.num_ways = num_ways
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec,
        self.split,
        episode_descr_config=episode_descr_config)
    episodes = self.generate_episodes(sampler, num_episodes)
    flush_size, _, _ = sampler.compute_chunk_sizes()
    # Group episodes every `switch_freq`.
    episode_group = [
        [episodes[0][0], episodes[1][0], episodes[2][0]],
        [episodes[3][0], episodes[4][0], episodes[5][0]],
        [episodes[6][0], episodes[7][0], episodes[8][0]],
    ]  # each episode is (input_string, class_id). We need only input_string.

    def get_episode_classes(episode):
      return [e.split(b'.')[0] for e in episode[flush_size:]]

    for episodes in episode_group:
      ref_classes = get_episode_classes(episodes[0])
      for episode in episodes[1:]:
        self.assertAllEqual(ref_classes, get_episode_classes(episode))

    # The classes in different episode groups should be different.
    ref_classes = get_episode_classes(episode_group[0][0])
    for episodes in episode_group[1:]:
      self.assertNotEqual(ref_classes, get_episode_classes(episodes[0]))
Example #10
 def test_fixed_shots(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_support=3,
                                                              num_query=7))
     self.generate_and_check(sampler, 10)
Example #11
    def test_deterministic_noshuffle(self):
        """Tests episode generation determinism when there is noshuffle queue."""
        num_episodes = 10
        init_rng = sampling.RNG
        seed = 20181120
        episode_streams = []
        chunk_sizes = []
        try:
            for _ in range(2):
                sampling.RNG = np.random.RandomState(seed)
                sampler = sampling.EpisodeDescriptionSampler(
                    self.dataset_spec,
                    self.split,
                    episode_descr_config=config.EpisodeDescriptionConfig())
                episodes = self.generate_episodes(sampler,
                                                  num_episodes,
                                                  shuffle=False)
                episode_streams.append(episodes)
                chunk_size = sampler.compute_chunk_sizes()
                chunk_sizes.append(chunk_size)
                for examples, targets in episodes:
                    self.check_episode_consistency(examples, targets,
                                                   chunk_size)

        finally:
            # Restore the original RNG
            sampling.RNG = init_rng

        self.assertEqual(chunk_sizes[0], chunk_sizes[1])

        for ((examples1, targets1), (examples2,
                                     targets2)) in zip(*episode_streams):
            self.assertAllEqual(examples1, examples2)
            self.assertAllEqual(targets1, targets2)
Example #12
 def test_shots_too_big(self):
     """Asserts failure if not enough examples to fulfill support and query."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split,
         config.EpisodeDescriptionConfig(num_support=5, num_query=15))
     with self.assertRaises(ValueError):
         sampler.sample_episode_description()
Example #13
 def test_ways_too_big(self):
     """Asserts failure if more ways than classes are available."""
     # Use Split.VALID as it only has 10 classes.
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, Split.VALID,
         config.EpisodeDescriptionConfig(num_ways=self.num_ways))
     with self.assertRaises(ValueError):
         sampler.sample_episode_description()
Example #14
 def test_large_support(self):
     """Support set larger than MAX_SUPPORT_SET_SIZE with fixed shots."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, Split.TRAIN,
         config.EpisodeDescriptionConfig(num_ways=30, num_support=20))
     _, support_chunk_size, _ = sampler.compute_chunk_sizes()
     self.assertGreater(support_chunk_size, test_utils.MAX_SUPPORT_SET_SIZE)
     sampler.sample_episode_description()
Example #15
    def __init__(self, data_path, mode, dataset, way, shot, query_train, query_test):

        self.data_path = data_path
        self.train_next_task = None
        self.validation_next_task = None
        self.test_next_task = None
        gin.parse_config_file('./meta_dataset_config.gin')

        fixed_way_shot_train = config.EpisodeDescriptionConfig(num_ways=way, num_support=shot, num_query=query_train)
        fixed_way_shot_test = config.EpisodeDescriptionConfig(num_ways=way, num_support=shot, num_query=query_test)

        if mode == 'train' or mode == 'train_test':
            self.train_next_task = self._init_dataset(dataset, learning_spec.Split.TRAIN, fixed_way_shot_train)
            self.validation_next_task = self._init_dataset(dataset, learning_spec.Split.VALID, fixed_way_shot_test)

        if mode == 'test' or mode == 'train_test' or mode == 'attack':
            self.test_next_task = self._init_dataset(dataset, learning_spec.Split.TEST, fixed_way_shot_test)
Example #16
 def test_query_too_big(self):
     """Asserts failure if all examples of a class are selected for query."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, self.split,
         config.EpisodeDescriptionConfig(num_query=10))
     with self.assertRaises(ValueError):
         # Sample enough times that we encounter a class with only 10 examples.
         for _ in range(10):
             sampler.sample_episode_description()
Example #17
 def test_large_query(self):
     """Query set larger than MAX_NUM_QUERY per class."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, Split.TRAIN,
         config.EpisodeDescriptionConfig(num_query=60))
     _, _, query_chunk_size = sampler.compute_chunk_sizes()
     self.assertGreater(
         query_chunk_size,
         test_utils.MAX_WAYS_UPPER_BOUND * test_utils.MAX_NUM_QUERY)
     sampler.sample_episode_description()
Example #18
 def test_large_ways(self):
     """Fixed num_ways above MAX_WAYS_UPPER_BOUND."""
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec, Split.TRAIN,
         config.EpisodeDescriptionConfig(num_ways=60, num_support=10))
     _, support_chunk_size, query_chunk_size = sampler.compute_chunk_sizes()
     self.assertGreater(support_chunk_size, test_utils.MAX_SUPPORT_SET_SIZE)
     self.assertGreater(
         query_chunk_size,
         test_utils.MAX_WAYS_UPPER_BOUND * test_utils.MAX_NUM_QUERY)
     sampler.sample_episode_description()
Example #19
    def check_same_as_generator(self, split, offset):
        """Tests that the targets are the one requested by the generator.

    Args:
      split: A value of the Split enum, which split to generate from.
      offset: An int, the difference between the absolute class IDs in the
        source, and the relative class IDs in the episodes.
    """
        num_episodes = 10
        seed = 20181121
        init_rng = sampling.RNG
        try:
            sampling.RNG = np.random.RandomState(seed)
            sampler = sampling.EpisodeDescriptionSampler(
                self.dataset_spec,
                split,
                episode_descr_config=config.EpisodeDescriptionConfig())
            # Each description is a (class_id, num_support, num_query) tuple.
            descriptions = [
                sampler.sample_episode_description()
                for _ in range(num_episodes)
            ]

            sampling.RNG = np.random.RandomState(seed)
            sampler = sampling.EpisodeDescriptionSampler(
                self.dataset_spec,
                split,
                episode_descr_config=config.EpisodeDescriptionConfig())
            episodes = self.generate_episodes(sampler, num_episodes)
            chunk_sizes = sampler.compute_chunk_sizes()
            self.assertEqual(len(descriptions), len(episodes))
            for (description, episode) in zip(descriptions, episodes):
                examples, targets = episode
                self.check_episode_consistency(examples, targets, chunk_sizes)
                _, targets_support_chunk, targets_query_chunk = split_into_chunks(
                    targets, chunk_sizes)
                self.check_description_vs_target_chunks(
                    description, targets_support_chunk, targets_query_chunk,
                    offset)
        finally:
            sampling.RNG = init_rng
Example #20
 def test_runtime_no_error(self, num_ways, num_support, num_query, kwargs):
     """Testing run-time errors thrown when arguments are not set correctly."""
     # The following scope removes the gin-config set.
     with gin.config_scope('none'):
         # No error thrown
         _ = sampling.EpisodeDescriptionSampler(
             self.dataset_spec,
             self.split,
             episode_descr_config=config.EpisodeDescriptionConfig(
                 num_ways=num_ways,
                 num_support=num_support,
                 num_query=num_query,
                 **kwargs))
Example #21
  def test_meta_dataset(self):
    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources, md_version='v1', meta_split=meta_split,
        source_sampling_seed=seed + 1, data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(os.path.join(FLAGS.meta_dataset_path,
                                                        md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined, but
    # the call to `gin.parse_config_file` above changes the default value to
    # True, which is why we have to explicitly bind a new default value here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
Example #22
 def _get_test_episode_description(self):
     return config.EpisodeDescriptionConfig(
         num_ways=None,
         num_support=None,
         num_query=None,
         min_ways=5,
         max_ways_upper_bound=50,
         max_num_query=10,
         max_support_set_size=500,
         max_support_size_contrib_per_class=100,
         min_log_weight=-0.69314718055994529,  # np.log(0.5)
         max_log_weight=0.69314718055994529,  # np.log(2)
         ignore_dag_ontology=False,
         ignore_bilevel_ontology=False)
Example #23
 def test_noskip_at_min(self):
   sampler = sampling.EpisodeDescriptionSampler(
       self.dataset_spec, self.split,
       config.EpisodeDescriptionConfig(min_examples_in_class=10))
    # We expect 10-example classes to be sampled at least some of the time.
    for _ in range(10):
      episode_description = sampler.sample_episode_description()
      if any(cid % 3 == 0 for cid, _, _ in episode_description):
        # Test should pass.
        break
    else:
      # The end of the loop was reached with no "break" triggered.
      # If no 10-example class is sampled after 10 iterations, it is an error.
      raise AssertionError('Classes with exactly `min_examples_in_class` '
                           'were not sampled.')
Example #24
  def test_episodic_overfit(self,
                            learner_class,
                            learner_config,
                            threshold=1.,
                            attempts=1):
    """Test that error goes down when training on a single episode.

    This can help check that the trained model and the evaluated one share
    the trainable parameters correctly.

    Args:
      learner_class: A subclass of Learner.
      learner_config: A string, the Learner-specific gin configuration string.
      threshold: A float (default 1.), the performance to reach at least once.
      attempts: An int (default 1), how many of the last steps should be checked
        when looking for a validation value reaching the threshold.
    """
    gin_config = '\n'.join((self.BASE_GIN_CONFIG, learner_config))
    gin.parse_config(gin_config)

    episode_config = config.EpisodeDescriptionConfig(
        num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)

    trainer_instance = trainer.Trainer(
        train_learner_class=learner_class,
        eval_learner_class=learner_class,
        is_training=True,
        train_dataset_list=['dummy'],
        eval_dataset_list=['dummy'],
        records_root_dir=self.temp_dir,
        checkpoint_dir=os.path.join(self.temp_dir, 'checkpoints'),
        train_episode_config=episode_config,
        eval_episode_config=episode_config,
        data_config=config.DataConfig(),
    )
    # Train 1 update at a time for the last `attempts - 1` steps.
    trainer_instance.num_updates -= (attempts - 1)
    trainer_instance.train()
    valid_accs = [trainer_instance.valid_acc]
    for _ in range(attempts - 1):
      trainer_instance.num_updates += 1
      trainer_instance.train()
      valid_accs.append(trainer_instance.valid_acc)
    self.assertGreaterEqual(max(valid_accs), threshold)
Example #25
 def _get_test_episode_description(self, num_ways, num_support, num_query):
     return config.EpisodeDescriptionConfig(
         num_ways=num_ways,
         num_support=num_support,
         num_query=num_query,
         min_ways=5,
         max_ways_upper_bound=50,
         max_num_query=10,
         max_support_set_size=500,
         max_support_size_contrib_per_class=100,
         min_log_weight=-0.69314718055994529,  # np.log(0.5)
         max_log_weight=0.69314718055994529,  # np.log(2)
         ignore_dag_ontology=False,
         ignore_bilevel_ontology=False,
         ignore_hierarchy_probability=0.0,
         simclr_episode_fraction=0.0)
Example #26
def main(unused_argv):
    logging.info(FLAGS.output_dir)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    gin.parse_config_files_and_bindings(FLAGS.gin_config,
                                        FLAGS.gin_bindings,
                                        finalize_config=True)
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()
    use_dag_ontology = (FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                        and not FLAGS.ignore_dag_ontology)
    use_bilevel_ontology = (FLAGS.dataset_name == 'omniglot'
                            and not FLAGS.ignore_bilevel_ontology)
    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=use_bilevel_ontology,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
        # reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)
    dataset = data_pipeline.take(FLAGS.num_episodes)

    images_per_class_dict = {}
    # Ignoring dataset number since we are loading one dataset.
    for episode_number, (episode, _) in enumerate(dataset):
        logging.info('Dumping episode %d', episode_number)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        path_train = utils.get_file_path(FLAGS.output_dir, episode_number,
                                         'train')
        path_test = utils.get_file_path(FLAGS.output_dir, episode_number,
                                        'test')
        utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
        utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
        images_per_class_dict[os.path.basename(path_train)] = (
            utils.get_label_counts(train_labels))
        images_per_class_dict[os.path.basename(path_test)] = (
            utils.get_label_counts(test_labels))
    info_path = utils.get_info_path(FLAGS.output_dir)
    with tf.io.gfile.GFile(info_path, 'w') as f:
        f.write(json.dumps(images_per_class_dict, indent=2))
Example #27
def make_md(lst, method, split='train', image_size=126, **kwargs):
    if split == 'train':
        SPLIT = learning_spec.Split.TRAIN
    elif split == 'val':
        SPLIT = learning_spec.Split.VALID
    elif split == 'test':
        SPLIT = learning_spec.Split.TEST
    else:
        raise ValueError(f'Unknown split: {split}')

    ALL_DATASETS = lst
    all_dataset_specs = []
    for dataset_name in ALL_DATASETS:
        dataset_records_path = os.path.join(BASE_PATH, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
        all_dataset_specs.append(dataset_spec)

    if method == 'episodic':
        use_bilevel_ontology_list = [False]*len(ALL_DATASETS)
        use_dag_ontology_list = [False]*len(ALL_DATASETS)
        # Enable ontology aware sampling for Omniglot and ImageNet. 
        for i, s in enumerate(ALL_DATASETS):
            if s == 'ilsvrc_2012':
                use_dag_ontology_list[i] = True
            if s == 'omniglot':
                use_bilevel_ontology_list[i] = True
        variable_ways_shots = config.EpisodeDescriptionConfig(
            num_query=None, num_support=None, num_ways=None)

        dataset_episodic = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=variable_ways_shots,
            split=SPLIT, image_size=image_size)

        return dataset_episodic

    elif method == 'batch':
        BATCH_SIZE = kwargs['batch_size']
        ADD_DATASET_OFFSET = False
        dataset_batch = pipeline.make_multisource_batch_pipeline(
            dataset_spec_list=all_dataset_specs, batch_size=BATCH_SIZE, split=SPLIT,
            image_size=image_size, add_dataset_offset=ADD_DATASET_OFFSET)
        return dataset_batch
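
A hedged usage sketch for make_md (assumptions: BASE_PATH points at converted Meta-Dataset records for the named sources, and any gin defaults required by config.EpisodeDescriptionConfig have already been parsed):

    # Hypothetical calls; the dataset names follow Example #3.
    episodic_ds = make_md(['omniglot', 'aircraft'], 'episodic',
                          split='train', image_size=84)
    batch_ds = make_md(['omniglot'], 'batch', split='train',
                       image_size=84, batch_size=16)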
Example #28
    def test_non_deterministic_shuffle(self):
        """Different Readers generate different episode compositions.

    Even with the same episode descriptions, the content should be different.
    """
        num_episodes = 10
        init_rng = sampling.RNG
        seed = 20181120
        episode_streams = []
        chunk_sizes = []
        try:
            for _ in range(2):
                sampling.RNG = np.random.RandomState(seed)
                sampler = sampling.EpisodeDescriptionSampler(
                    self.dataset_spec,
                    self.split,
                    episode_descr_config=config.EpisodeDescriptionConfig())
                episodes = self.generate_episodes(sampler, num_episodes)
                episode_streams.append(episodes)
                chunk_size = sampler.compute_chunk_sizes()
                chunk_sizes.append(chunk_size)
                for examples, targets in episodes:
                    self.check_episode_consistency(examples, targets,
                                                   chunk_size)

        finally:
            # Restore the original RNG
            sampling.RNG = init_rng

        self.assertEqual(chunk_sizes[0], chunk_sizes[1])

        # It is unlikely that all episodes will be the same
        num_identical_episodes = 0
        for ((examples1, targets1), (examples2,
                                     targets2)) in zip(*episode_streams):
            self.check_episode_consistency(examples1, targets1, chunk_sizes[0])
            self.check_episode_consistency(examples2, targets2, chunk_sizes[1])
            self.assertAllEqual(targets1, targets2)
            if all(examples1 == examples2):
                num_identical_episodes += 1

        self.assertNotEqual(num_identical_episodes, num_episodes)
Example #29
    def test_flush_logic(self):
        """Tests the "flush" logic avoiding example duplication in an episode."""
        # Generate two episodes from un-shuffled data sources. For classes where
        # there are enough examples for both, new examples should be used for the
        # second episode. Otherwise, the first examples should be re-used.
        # A dataset spec with classes containing between 10 and 29 examples.
        num_classes = 30
        dataset_spec = DatasetSpecification(
            name=None,
            classes_per_split={
                Split.TRAIN: num_classes,
                Split.VALID: 0,
                Split.TEST: 0
            },
            images_per_class={i: 10 + i
                              for i in range(num_classes)},
            class_names=None,
            path=None,
            file_pattern='{}.tfrecords')
        # Sample from all train classes: 5 support + 5 query examples per class
        # in each episode.
        sampler = sampling.EpisodeDescriptionSampler(
            dataset_spec,
            Split.TRAIN,
            episode_descr_config=config.EpisodeDescriptionConfig(
                num_ways=num_classes, num_support=5, num_query=5))
        episodes = self.generate_episodes(sampler,
                                          num_episodes=2,
                                          shuffle=False)

        # The "flush" part of the second episode should contain 0 from class_id 0, 1
        # for 1, ..., 9 for 9, and then 0 for 10 and the following.
        chunk_sizes = sampler.compute_chunk_sizes()
        _, episode2 = episodes
        examples2, targets2 = episode2
        flush_target2, _, _ = split_into_chunks(targets2, chunk_sizes)
        for class_id in range(10):
            self.assertEqual(
                sum(target == class_id for target in flush_target2), class_id)
        for class_id in range(10, num_classes):
            self.assertEqual(
                sum(target == class_id for target in flush_target2), 0)

        # The "support" part of the second episode should start at example 0 for
        # class_ids from 0 to 9 (included), and at example 10 for class_id 10 and
        # higher.
        _, support_examples2, query_examples2 = split_into_chunks(
            examples2, chunk_sizes)

        def _build_class_id_to_example_ids(examples):
            # Build a mapping: class_id -> list of example ids
            mapping = collections.defaultdict(list)
            for example in examples:
                if not example:
                    # Padding is at the end
                    break
                class_id, example_id = example.decode().split('.')
                mapping[int(class_id)].append(int(example_id))
            return mapping

        support2_example_ids = _build_class_id_to_example_ids(
            support_examples2)
        query2_example_ids = _build_class_id_to_example_ids(query_examples2)

        for class_id in range(10):
            self.assertCountEqual(support2_example_ids[class_id],
                                  list(range(5)))
            self.assertCountEqual(query2_example_ids[class_id],
                                  list(range(5, 10)))

        for class_id in range(10, num_classes):
            self.assertCountEqual(support2_example_ids[class_id],
                                  list(range(10, 15)))
            self.assertCountEqual(query2_example_ids[class_id],
                                  list(range(15, 20)))
Example #30
 def test_fixed_ways(self):
     sampler = sampling.EpisodeDescriptionSampler(
         self.dataset_spec,
         self.split,
         episode_descr_config=config.EpisodeDescriptionConfig(num_ways=12))
     self.generate_and_check(sampler, 10)
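
Across these examples the recurring pattern is: load a dataset specification, build a config.EpisodeDescriptionConfig (either with fixed values, or with None fields that fall back to gin defaults), and hand it to a sampler or input pipeline. A minimal end-to-end sketch, assuming the meta_dataset package is installed, records_path points at a converted dataset, and gin defaults are parsed as in Examples #15 and #26:

    import gin
    from meta_dataset.data import config, learning_spec, sampling
    from meta_dataset.data import dataset_spec as dataset_spec_lib

    # Assumption: this gin file supplies defaults for the fields left
    # unspecified below, as in Example #15.
    gin.parse_config_file('./meta_dataset_config.gin')

    # Assumption: converted .tfrecords and the dataset spec live here.
    records_path = '/path/to/records/omniglot'
    dataset_spec = dataset_spec_lib.load_dataset_spec(records_path)

    # Fixed 5-way, 1-shot episodes with 10 query examples per class.
    episode_config = config.EpisodeDescriptionConfig(
        num_ways=5, num_support=1, num_query=10)
    sampler = sampling.EpisodeDescriptionSampler(
        dataset_spec, learning_spec.Split.TRAIN,
        episode_descr_config=episode_config)

    # Each description is a (class_id, num_support, num_query) tuple per
    # sampled class, as noted in Example #19.
    description = sampler.sample_episode_description()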