def set_fixed_episode_config(self):
    """Set the fixed episode description configuration(s).

    Builds `self.fixed_ways_shots` from `self.episode_config` and
    `self.fixed_ways_shots_valid` from `self.valid_episode_config`,
    whenever the corresponding tuple is set. Each tuple is indexed as
    [1]=ways, [2]=support shots, [3]=query shots (index 0 unused here —
    TODO confirm against the caller that sets these tuples).
    """
    def _make_fixed_config(ep):
        # Pin the number of ways by making min_ways == max_ways_upper_bound
        # == num_ways, so the sampler cannot vary it.
        return config.EpisodeDescriptionConfig(
            num_ways=ep[1],
            num_support=ep[2],
            num_query=ep[3],
            min_ways=ep[1],
            max_ways_upper_bound=ep[1],
            max_num_query=20,
            max_support_set_size=20,
            max_support_size_contrib_per_class=200,
            min_log_weight=-0.69314718055994529,  # np.log(0.5)
            max_log_weight=0.69314718055994529,   # np.log(2)
            ignore_dag_ontology=True,
            ignore_bilevel_ontology=True,
            min_examples_in_class=0)

    # Previously the 14-argument construction above was duplicated verbatim
    # for both branches; the helper keeps the two configs in sync.
    if self.episode_config is not None:
        self.fixed_ways_shots = _make_fixed_config(self.episode_config)
    if self.valid_episode_config is not None:
        self.fixed_ways_shots_valid = _make_fixed_config(
            self.valid_episode_config)
def __init__(self, mode, train_set=None, validation_set=None, test_set=None):
    """Build the per-mode episode iterators over the configured sources."""
    super(MetaDatasetEpisodeReader, self).__init__(mode, train_set,
                                                   validation_set, test_set)
    if mode == 'train':
        # A config with all-None shape falls back to the gin configuration.
        episode_description = config.EpisodeDescriptionConfig(None, None, None)
        self.train_dataset_next_task = self._init_multi_source_dataset(
            train_set, SPLIT_NAME_TO_SPLIT['train'], episode_description)
    elif mode == 'val':
        episode_description = config.EpisodeDescriptionConfig(None, None, None)
        for source in validation_set:
            self.validation_set_dict[source] = self._init_single_source_dataset(
                source, SPLIT_NAME_TO_SPLIT['val'], episode_description)
    elif mode == 'test':
        episode_description = config.EpisodeDescriptionConfig(None, None, None)
        for source in test_set:
            self.test_set_dict[source] = self._init_single_source_dataset(
                source, SPLIT_NAME_TO_SPLIT['test'], episode_description)
def __init__(self, datasets, split, fixed_ways=None, fixed_support=None,
             fixed_query=None, use_ontology=True):
    """Load a multi-source Meta-Dataset episode pipeline.

    Args:
      datasets: list/tuple of dataset name strings; each must be a record
        directory under BASE_PATH.
      split: one of 'train', 'valid', 'test'.
      fixed_ways: optional int; fixes the number of ways per episode.
      fixed_support: optional int; fixes the number of support shots.
      fixed_query: optional int; fixes the number of query shots.
        For any of the three, None defers to the gin configuration.
      use_ontology: whether to enable ontology-aware sampling for
        omniglot (bilevel) and ilsvrc_2012 (DAG).
    """
    assert split in ['train', 'valid', 'test']
    assert isinstance(datasets, (list, tuple))
    assert isinstance(datasets[0], str)
    print(f'Loading MetaDataset for {split}..')
    split = getattr(learning_spec.Split, split.upper())
    self.datasets = datasets
    # Ontology settings: only omniglot and ilsvrc_2012 have ontologies.
    use_bilevel_ontology_list = []
    use_dag_ontology_list = []
    for dataset in self.datasets:
        bilevel = dag = False
        if dataset == 'omniglot' and use_ontology:
            bilevel = True
        elif dataset == 'ilsvrc_2012' and use_ontology:
            dag = True
        use_bilevel_ontology_list.append(bilevel)
        use_dag_ontology_list.append(dag)
    assert len(self.datasets) == len(use_bilevel_ontology_list)
    assert len(self.datasets) == len(use_dag_ontology_list)
    # Load the dataset specification for every requested source.
    all_dataset_specs = []
    for dataset_name in self.datasets:
        dataset_records_path = os.path.join(BASE_PATH, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
        all_dataset_specs.append(dataset_spec)
    if fixed_ways and use_ontology:
        # Ontology-aware samplers ignore num_ways; pin the ways by setting
        # both the lower and upper bounds instead.
        max_ways_upper_bound = min_ways = fixed_ways
        self.episode_config = config.EpisodeDescriptionConfig(
            num_query=fixed_query, num_support=fixed_support,
            min_ways=min_ways, max_ways_upper_bound=max_ways_upper_bound,
            num_ways=None)
    else:
        # Episode description config (if num is None, use gin configuration).
        self.episode_config = config.EpisodeDescriptionConfig(
            num_query=fixed_query, num_support=fixed_support,
            num_ways=fixed_ways)
    # Episode pipeline.
    self.episode_pipeline, self.n_total_classes = \
        pipeline.make_multisource_episode_pipeline2(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=self.episode_config,
            split=split, image_size=84)
    # `join` accepts the list directly; the old `[d for d in ...]` wrapper
    # was a redundant identity comprehension.
    print('MetaDataset loaded: ', ', '.join(self.datasets))
def test_skip_too_many(self):
    """Sampler construction fails when skipping drops "valid" below MIN_WAYS (5)."""
    skip_config = config.EpisodeDescriptionConfig(
        min_examples_in_class=self.min_examples_in_class)
    with self.assertRaises(ValueError):
        sampling.EpisodeDescriptionSampler(
            self.dataset_spec, Split.VALID, skip_config)
def test_too_many_ways(self):
    """600 ways cannot each supply one example with default variable shots."""
    huge_ways_config = config.EpisodeDescriptionConfig(num_ways=600)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.TRAIN, huge_ways_config)
    with self.assertRaises(ValueError):
        sampler.sample_episode_description()
def test_fixed_shots(self):
    """Fixed support/query shots produce the expected episode structure."""
    shots_config = config.EpisodeDescriptionConfig(num_support=3, num_query=7)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, episode_descr_config=shots_config)
    self.check_expected_structure(sampler)
def test_train(self):
    """Tests that a few episodes are consistent."""
    default_config = config.EpisodeDescriptionConfig()
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.TRAIN, episode_descr_config=default_config)
    self.generate_and_check(sampler, 10)
def make_sampler(self):
    """Build a sampler using this test case's fixed ways/support/query."""
    fixed_shape_config = config.EpisodeDescriptionConfig(
        num_ways=self.num_ways,
        num_support=self.num_support,
        num_query=self.num_query)
    return sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, fixed_shape_config)
def test_episode_switch_frequency(self):
    """Tests that episode descriptions switch only every `switch_freq` episodes."""
    num_episodes = 9
    switch_freq = 3
    num_ways = 5
    episode_descr_config = config.EpisodeDescriptionConfig()
    episode_descr_config.episode_description_switch_frequency = switch_freq
    episode_descr_config.num_ways = num_ways
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split,
        episode_descr_config=episode_descr_config)
    episodes = self.generate_episodes(sampler, num_episodes)
    # Previously computed twice; once is enough.
    flush_size, _, _ = sampler.compute_chunk_sizes()
    # Group episodes every `switch_freq`.
    episode_groups = [
        [episodes[0][0], episodes[1][0], episodes[2][0]],
        [episodes[3][0], episodes[4][0], episodes[5][0]],
        [episodes[6][0], episodes[7][0], episodes[8][0]],
    ]

    # Each episode is (input_string, class_id). We need only input_string;
    # the class name is the part of the string before the '.'.
    def get_episode_classes(episode):
        return [e.split(b'.')[0] for e in episode[flush_size:]]

    # Within each group, every episode should use the same classes.
    # (The loop variable no longer shadows the outer `episodes` list.)
    for group in episode_groups:
        ref_classes = get_episode_classes(group[0])
        for episode in group[1:]:
            self.assertAllEqual(ref_classes, get_episode_classes(episode))
    # The classes in different groups should be different.
    ref_classes = get_episode_classes(episode_groups[0][0])
    for group in episode_groups[1:]:
        self.assertNotEqual(ref_classes, get_episode_classes(group[0]))
def test_fixed_shots(self):
    """Episodes with fixed support/query shots generate consistently."""
    shots_config = config.EpisodeDescriptionConfig(num_support=3, num_query=7)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, episode_descr_config=shots_config)
    self.generate_and_check(sampler, 10)
def test_deterministic_noshuffle(self):
    """Tests episode generation determinism when there is noshuffle queue."""
    num_episodes = 10
    init_rng = sampling.RNG
    seed = 20181120
    episode_streams = []
    chunk_sizes = []
    try:
        # Generate two episode streams from identically-seeded samplers.
        for _ in range(2):
            # Reseed the module-level RNG before each stream.
            sampling.RNG = np.random.RandomState(seed)
            sampler = sampling.EpisodeDescriptionSampler(
                self.dataset_spec,
                self.split,
                episode_descr_config=config.EpisodeDescriptionConfig())
            episodes = self.generate_episodes(sampler, num_episodes,
                                              shuffle=False)
            episode_streams.append(episodes)
            chunk_size = sampler.compute_chunk_sizes()
            chunk_sizes.append(chunk_size)
            for examples, targets in episodes:
                self.check_episode_consistency(examples, targets, chunk_size)
    finally:
        # Restore the original RNG
        sampling.RNG = init_rng
    self.assertEqual(chunk_sizes[0], chunk_sizes[1])
    # With shuffling disabled, both streams must match example-for-example.
    for ((examples1, targets1),
         (examples2, targets2)) in zip(*episode_streams):
        self.assertAllEqual(examples1, examples2)
        self.assertAllEqual(targets1, targets2)
def test_shots_too_big(self):
    """Asserts failure if not enough examples to fulfill support and query."""
    oversized_config = config.EpisodeDescriptionConfig(num_support=5,
                                                       num_query=15)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, oversized_config)
    with self.assertRaises(ValueError):
        sampler.sample_episode_description()
def test_ways_too_big(self):
    """Requesting more ways than available classes raises at sampling time."""
    # Use Split.VALID as it only has 10 classes.
    too_many_ways = config.EpisodeDescriptionConfig(num_ways=self.num_ways)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.VALID, too_many_ways)
    with self.assertRaises(ValueError):
        sampler.sample_episode_description()
def test_large_support(self):
    """Support set larger than MAX_SUPPORT_SET_SIZE with fixed shots."""
    big_support_config = config.EpisodeDescriptionConfig(num_ways=30,
                                                         num_support=20)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.TRAIN, big_support_config)
    _, support_chunk_size, _ = sampler.compute_chunk_sizes()
    # 30 ways x 20 shots exceeds the cap; sampling should still succeed.
    self.assertGreater(support_chunk_size, test_utils.MAX_SUPPORT_SET_SIZE)
    sampler.sample_episode_description()
def __init__(self, data_path, mode, dataset, way, shot, query_train,
             query_test):
    """Set up fixed-way/shot episode iterators for the requested mode."""
    self.data_path = data_path
    self.train_next_task = None
    self.validation_next_task = None
    self.test_next_task = None
    gin.parse_config_file('./meta_dataset_config.gin')
    # Train and eval configs differ only in the number of query shots.
    fixed_way_shot_train = config.EpisodeDescriptionConfig(
        num_ways=way, num_support=shot, num_query=query_train)
    fixed_way_shot_test = config.EpisodeDescriptionConfig(
        num_ways=way, num_support=shot, num_query=query_test)
    if mode in ('train', 'train_test'):
        self.train_next_task = self._init_dataset(
            dataset, learning_spec.Split.TRAIN, fixed_way_shot_train)
        self.validation_next_task = self._init_dataset(
            dataset, learning_spec.Split.VALID, fixed_way_shot_test)
    if mode in ('test', 'train_test', 'attack'):
        self.test_next_task = self._init_dataset(
            dataset, learning_spec.Split.TEST, fixed_way_shot_test)
def test_query_too_big(self):
    """Asserts failure if all examples of a class are selected for query."""
    query_config = config.EpisodeDescriptionConfig(num_query=10)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, query_config)
    with self.assertRaises(ValueError):
        # Sample enough times that we encounter a class with only 10 examples.
        for _ in range(10):
            sampler.sample_episode_description()
def test_large_query(self):
    """Query set larger than MAX_NUM_QUERY per class."""
    big_query_config = config.EpisodeDescriptionConfig(num_query=60)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.TRAIN, big_query_config)
    _, _, query_chunk_size = sampler.compute_chunk_sizes()
    # The query chunk overflows the cap; sampling should still succeed.
    self.assertGreater(
        query_chunk_size,
        test_utils.MAX_WAYS_UPPER_BOUND * test_utils.MAX_NUM_QUERY)
    sampler.sample_episode_description()
def test_large_ways(self):
    """Fixed num_ways above MAX_WAYS_UPPER_BOUND."""
    big_ways_config = config.EpisodeDescriptionConfig(num_ways=60,
                                                      num_support=10)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, Split.TRAIN, big_ways_config)
    _, support_chunk_size, query_chunk_size = sampler.compute_chunk_sizes()
    # Both chunks overflow their caps, yet sampling should still succeed.
    self.assertGreater(support_chunk_size, test_utils.MAX_SUPPORT_SET_SIZE)
    self.assertGreater(
        query_chunk_size,
        test_utils.MAX_WAYS_UPPER_BOUND * test_utils.MAX_NUM_QUERY)
    sampler.sample_episode_description()
def check_same_as_generator(self, split, offset):
    """Tests that the targets are the one requested by the generator.

    Args:
      split: A value of the Split enum, which split to generate from.
      offset: An int, the difference between the absolute class IDs in the
        source, and the relative class IDs in the episodes.
    """
    num_episodes = 10
    seed = 20181121
    init_rng = sampling.RNG
    try:
        # Seed the module RNG so the description sequence can be replayed.
        sampling.RNG = np.random.RandomState(seed)
        sampler = sampling.EpisodeDescriptionSampler(
            self.dataset_spec, split,
            episode_descr_config=config.EpisodeDescriptionConfig())
        # Each description is a (class_id, num_support, num_query) tuple.
        descriptions = [
            sampler.sample_episode_description() for _ in range(num_episodes)
        ]
        # Re-seed and rebuild the sampler: the generated episodes should now
        # follow the exact same sequence of descriptions.
        sampling.RNG = np.random.RandomState(seed)
        sampler = sampling.EpisodeDescriptionSampler(
            self.dataset_spec, split,
            episode_descr_config=config.EpisodeDescriptionConfig())
        episodes = self.generate_episodes(sampler, num_episodes)
        chunk_sizes = sampler.compute_chunk_sizes()
        self.assertEqual(len(descriptions), len(episodes))
        for (description, episode) in zip(descriptions, episodes):
            examples, targets = episode
            self.check_episode_consistency(examples, targets, chunk_sizes)
            # Only the support and query chunks carry the episode targets;
            # the flush chunk is discarded.
            _, targets_support_chunk, targets_query_chunk = split_into_chunks(
                targets, chunk_sizes)
            self.check_description_vs_target_chunks(
                description, targets_support_chunk, targets_query_chunk,
                offset)
    finally:
        # Restore the original RNG
        sampling.RNG = init_rng
def test_runtime_no_error(self, num_ways, num_support, num_query, kwargs):
    """Testing run-time errors thrown when arguments are not set correctly."""
    # The 'none' scope removes the gin-config set.
    with gin.config_scope('none'):
        descr_config = config.EpisodeDescriptionConfig(
            num_ways=num_ways,
            num_support=num_support,
            num_query=num_query,
            **kwargs)
        # Constructing the sampler must not raise.
        _ = sampling.EpisodeDescriptionSampler(
            self.dataset_spec, self.split, episode_descr_config=descr_config)
def test_meta_dataset(self):
    """Compares TFDS-backed episodes against the original Meta-Dataset pipeline."""
    # Load the shared, deterministic data configuration for both pipelines.
    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')
    # Generate episodes through the TFDS-based API.
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources, md_version='v1', meta_split=meta_split,
        source_sampling_seed=seed + 1, data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())
    # Re-seed and generate the same episodes with the original implementation.
    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(
            os.path.join(FLAGS.meta_dataset_path, md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined,
    # but the call to `gin.parse_config_file` above changes the default value
    # to True, which is why we have to explicitly bind a new default value
    # here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())
    # Both episode tensors and the sampled source IDs must match
    # element-wise across the two pipelines.
    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):
        np.testing.assert_equal(tfds_source_id, md_source_id)
        for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
            np.testing.assert_allclose(tfds_tensor, md_tensor)
def _get_test_episode_description(self):
    """Return a variable-ways/shots episode config with test-default bounds.

    All of num_ways/num_support/num_query are None, so the sampler chooses
    episode shapes within the bounds given below.
    """
    return config.EpisodeDescriptionConfig(
        num_ways=None,
        num_support=None,
        num_query=None,
        min_ways=5,
        max_ways_upper_bound=50,
        max_num_query=10,
        max_support_set_size=500,
        max_support_size_contrib_per_class=100,
        min_log_weight=-0.69314718055994529,  # np.log(0.5)
        max_log_weight=0.69314718055994529,  # np.log(2)
        ignore_dag_ontology=False,
        ignore_bilevel_ontology=False)
def test_noskip_at_min(self):
    """Classes with exactly `min_examples_in_class` examples still get sampled."""
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split,
        config.EpisodeDescriptionConfig(min_examples_in_class=10))
    # We expect 10-example classes (cid % 3 == 0) to be sampled at least once
    # within 10 attempts.
    found_min_class = False
    for _ in range(10):
        episode_description = sampler.sample_episode_description()
        if any(cid % 3 == 0 for cid, _, _ in episode_description):
            found_min_class = True
            break
    if not found_min_class:
        # If no 10-example class is sampled after 10 iterations, it is an
        # error.
        raise AssertionError('Classes with exactly `min_examples_in_class` '
                             'were not sampled.')
def test_episodic_overfit(self, learner_class, learner_config, threshold=1.,
                          attempts=1):
    """Test that error goes down when training on a single episode.

    This can help check that the trained model and the evaluated one share
    the trainable parameters correctly.

    Args:
      learner_class: A subclass of Learner.
      learner_config: A string, the Learner-specific gin configuration string.
      threshold: A float (default 1.), the performance to reach at least once.
      attempts: An int (default 1), how many of the last steps should be
        checked when looking for a validation value reaching the threshold
        (default 1).
    """
    gin_config = '\n'.join((self.BASE_GIN_CONFIG, learner_config))
    gin.parse_config(gin_config)
    # 1-shot / 1-query episodes with as many ways as available examples.
    episode_config = config.EpisodeDescriptionConfig(
        num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)
    trainer_instance = trainer.Trainer(
        train_learner_class=learner_class,
        eval_learner_class=learner_class,
        is_training=True,
        train_dataset_list=['dummy'],
        eval_dataset_list=['dummy'],
        records_root_dir=self.temp_dir,
        checkpoint_dir=os.path.join(self.temp_dir, 'checkpoints'),
        train_episode_config=episode_config,
        eval_episode_config=episode_config,
        data_config=config.DataConfig(),
        # BEGIN GOOGLE_INTERNAL
        real_episodes=False,
        real_episodes_results_dir='',
        # END GOOGLE_INTERNAL
    )
    # Train 1 update at a time for the last `attempts - 1` steps.
    trainer_instance.num_updates -= (attempts - 1)
    trainer_instance.train()
    valid_accs = [trainer_instance.valid_acc]
    for _ in range(attempts - 1):
        trainer_instance.num_updates += 1
        trainer_instance.train()
        valid_accs.append(trainer_instance.valid_acc)
    # The threshold only needs to be reached in one of the checked attempts.
    self.assertGreaterEqual(max(valid_accs), threshold)
def _get_test_episode_description(self, num_ways, num_support, num_query):
    """Return an episode config with the given shape and test-default bounds.

    Args:
      num_ways: Number of classes per episode, or None for variable ways.
      num_support: Support shots per class, or None for variable shots.
      num_query: Query shots per class, or None for variable shots.

    Returns:
      A config.EpisodeDescriptionConfig.
    """
    return config.EpisodeDescriptionConfig(
        num_ways=num_ways,
        num_support=num_support,
        num_query=num_query,
        min_ways=5,
        max_ways_upper_bound=50,
        max_num_query=10,
        max_support_set_size=500,
        max_support_size_contrib_per_class=100,
        min_log_weight=-0.69314718055994529,  # np.log(0.5)
        max_log_weight=0.69314718055994529,  # np.log(2)
        ignore_dag_ontology=False,
        ignore_bilevel_ontology=False,
        ignore_hierarchy_probability=0.0,
        simclr_episode_fraction=0.0)
def main(unused_argv):
    """Dump `FLAGS.num_episodes` sampled episodes as TFRecord files.

    Writes per-episode train/test record files to `FLAGS.output_dir`, plus a
    JSON info file mapping each record file name to its per-class label
    counts.
    """
    logging.info(FLAGS.output_dir)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings,
                                        finalize_config=True)
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()
    # Ontology-aware sampling applies only to ImageNet (DAG ontology) and
    # Omniglot (bilevel ontology), unless explicitly disabled via flags.
    use_dag_ontology = (FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                        and not FLAGS.ignore_dag_ontology)
    use_bilevel_ontology = (FLAGS.dataset_name == 'omniglot' and
                            not FLAGS.ignore_bilevel_ontology)
    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=use_bilevel_ontology,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and
        # check reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)
    dataset = data_pipeline.take(FLAGS.num_episodes)
    images_per_class_dict = {}
    # Ignoring dataset number since we are loading one dataset.
    for episode_number, (episode, _) in enumerate(dataset):
        logging.info('Dumping episode %d', episode_number)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        path_train = utils.get_file_path(FLAGS.output_dir, episode_number,
                                         'train')
        path_test = utils.get_file_path(FLAGS.output_dir, episode_number,
                                        'test')
        utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
        utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
        images_per_class_dict[os.path.basename(path_train)] = (
            utils.get_label_counts(train_labels))
        images_per_class_dict[os.path.basename(path_test)] = (
            utils.get_label_counts(test_labels))
    # Record per-file label counts next to the dumped episodes.
    info_path = utils.get_info_path(FLAGS.output_dir)
    with tf.io.gfile.GFile(info_path, 'w') as f:
        f.write(json.dumps(images_per_class_dict, indent=2))
def make_md(lst, method, split='train', image_size=126, **kwargs):
    """Build a Meta-Dataset input pipeline over the given sources.

    Args:
      lst: list of dataset name strings; each must be a record directory
        under BASE_PATH.
      method: 'episodic' for an episode pipeline, or 'batch' for a batch
        pipeline (requires kwargs['batch_size']).
      split: one of 'train', 'val', 'test'.
      image_size: side length to which images are resized.
      **kwargs: extra options; only 'batch_size' is read (batch method).

    Returns:
      A multi-source episode or batch pipeline.

    Raises:
      ValueError: if `split` or `method` is not recognized. (Previously an
        unknown split surfaced later as a NameError on SPLIT, and an unknown
        method silently returned None.)
    """
    if split == 'train':
        split_enum = learning_spec.Split.TRAIN
    elif split == 'val':
        split_enum = learning_spec.Split.VALID
    elif split == 'test':
        split_enum = learning_spec.Split.TEST
    else:
        raise ValueError(f'Unknown split: {split!r}')
    # Load the dataset specification for every requested source.
    all_dataset_specs = []
    for dataset_name in lst:
        dataset_records_path = os.path.join(BASE_PATH, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
        all_dataset_specs.append(dataset_spec)
    if method == 'episodic':
        use_bilevel_ontology_list = [False] * len(lst)
        use_dag_ontology_list = [False] * len(lst)
        # Enable ontology aware sampling for Omniglot and ImageNet.
        for i, s in enumerate(lst):
            if s == 'ilsvrc_2012':
                use_dag_ontology_list[i] = True
            if s == 'omniglot':
                use_bilevel_ontology_list[i] = True
        # All-None shape: the sampler varies ways and shots per episode.
        variable_ways_shots = config.EpisodeDescriptionConfig(
            num_query=None, num_support=None, num_ways=None)
        dataset_episodic = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=variable_ways_shots,
            split=split_enum,
            image_size=image_size)
        return dataset_episodic
    elif method == 'batch':
        dataset_batch = pipeline.make_multisource_batch_pipeline(
            dataset_spec_list=all_dataset_specs,
            batch_size=kwargs['batch_size'],
            split=split_enum,
            image_size=image_size,
            add_dataset_offset=False)
        return dataset_batch
    raise ValueError(f'Unknown method: {method!r}')
def test_non_deterministic_shuffle(self):
    """Different Readers generate different episode compositions.

    Even with the same episode descriptions, the content should be different.
    """
    num_episodes = 10
    init_rng = sampling.RNG
    seed = 20181120
    episode_streams = []
    chunk_sizes = []
    try:
        # Two streams with identically-seeded description RNGs; the example
        # shuffling itself is NOT reseeded, so contents may differ.
        for _ in range(2):
            sampling.RNG = np.random.RandomState(seed)
            sampler = sampling.EpisodeDescriptionSampler(
                self.dataset_spec,
                self.split,
                episode_descr_config=config.EpisodeDescriptionConfig())
            episodes = self.generate_episodes(sampler, num_episodes)
            episode_streams.append(episodes)
            chunk_size = sampler.compute_chunk_sizes()
            chunk_sizes.append(chunk_size)
            for examples, targets in episodes:
                self.check_episode_consistency(examples, targets, chunk_size)
    finally:
        # Restore the original RNG
        sampling.RNG = init_rng
    self.assertEqual(chunk_sizes[0], chunk_sizes[1])
    # It is unlikely that all episodes will be the same
    num_identical_episodes = 0
    for ((examples1, targets1),
         (examples2, targets2)) in zip(*episode_streams):
        self.check_episode_consistency(examples1, targets1, chunk_sizes[0])
        self.check_episode_consistency(examples2, targets2, chunk_sizes[1])
        # Same seeded descriptions -> identical targets in both streams.
        self.assertAllEqual(targets1, targets2)
        # NOTE(review): `examples1 == examples2` presumably compares
        # element-wise (numpy arrays); confirm examples are arrays.
        if all(examples1 == examples2):
            num_identical_episodes += 1
    self.assertNotEqual(num_identical_episodes, num_episodes)
def test_flush_logic(self):
    """Tests the "flush" logic avoiding example duplication in an episode."""
    # Generate two episodes from un-shuffled data sources. For classes where
    # there are enough examples for both, new examples should be used for the
    # second episodes. Otherwise, the first examples should be re-used.
    # A data_spec with classes between 10 and 29 examples.
    num_classes = 30
    dataset_spec = DatasetSpecification(
        name=None,
        classes_per_split={
            Split.TRAIN: num_classes,
            Split.VALID: 0,
            Split.TEST: 0
        },
        images_per_class={i: 10 + i for i in range(num_classes)},
        class_names=None,
        path=None,
        file_pattern='{}.tfrecords')
    # Sample from all train classes, 5 + 5 examples from each episode
    sampler = sampling.EpisodeDescriptionSampler(
        dataset_spec,
        Split.TRAIN,
        episode_descr_config=config.EpisodeDescriptionConfig(
            num_ways=num_classes, num_support=5, num_query=5))
    episodes = self.generate_episodes(sampler, num_episodes=2, shuffle=False)
    # The "flush" part of the second episode should contain 0 from class_id
    # 0, 1 for 1, ..., 9 for 9, and then 0 for 10 and the following.
    chunk_sizes = sampler.compute_chunk_sizes()
    _, episode2 = episodes
    examples2, targets2 = episode2
    flush_target2, _, _ = split_into_chunks(targets2, chunk_sizes)
    for class_id in range(10):
        self.assertEqual(
            sum(target == class_id for target in flush_target2), class_id)
    for class_id in range(10, num_classes):
        self.assertEqual(
            sum(target == class_id for target in flush_target2), 0)
    # The "support" part of the second episode should start at example 0 for
    # class_ids from 0 to 9 (included), and at example 10 for class_id 10 and
    # higher.
    _, support_examples2, query_examples2 = split_into_chunks(
        examples2, chunk_sizes)

    def _build_class_id_to_example_ids(examples):
        # Build a mapping: class_id -> list of example ids
        mapping = collections.defaultdict(list)
        for example in examples:
            if not example:
                # Padding is at the end
                break
            # Example strings are "<class_id>.<example_id>".
            class_id, example_id = example.decode().split('.')
            mapping[int(class_id)].append(int(example_id))
        return mapping

    support2_example_ids = _build_class_id_to_example_ids(
        support_examples2)
    query2_example_ids = _build_class_id_to_example_ids(query_examples2)
    # Classes 0..9 have too few examples left for a fresh episode, so the
    # first episode's examples (ids 0..9) are re-used; classes 10+ get
    # fresh examples (ids 10..19).
    for class_id in range(10):
        self.assertCountEqual(support2_example_ids[class_id], list(range(5)))
        self.assertCountEqual(query2_example_ids[class_id],
                              list(range(5, 10)))
    for class_id in range(10, num_classes):
        self.assertCountEqual(support2_example_ids[class_id],
                              list(range(10, 15)))
        self.assertCountEqual(query2_example_ids[class_id],
                              list(range(15, 20)))
def test_fixed_ways(self):
    """Episodes with a fixed number of ways generate consistently."""
    ways_config = config.EpisodeDescriptionConfig(num_ways=12)
    sampler = sampling.EpisodeDescriptionSampler(
        self.dataset_spec, self.split, episode_descr_config=ways_config)
    self.generate_and_check(sampler, 10)