def _get_dataset_spec(self, items):
    """Loads the DatasetSpecification(s) for one or several datasets.

    Each dataset name is resolved relative to `self.data_path` and read with
    `dataset_spec_lib.load_dataset_spec`.

    Args:
        items: Either a single dataset name, or a list/tuple of dataset names.

    Returns:
        A list of dataset specs (same order as `items`) when `items` is a list
        or tuple; otherwise the single dataset spec for `items`.
    """
    def _load(dataset_name):
        # Load the spec stored under self.data_path/<dataset_name>.
        dataset_records_path = os.path.join(self.data_path, dataset_name)
        return dataset_spec_lib.load_dataset_spec(dataset_records_path)

    # Tuples are treated like lists; previously a tuple fell through to the
    # single-name branch and failed inside os.path.join.
    if isinstance(items, (list, tuple)):
        return [_load(dataset_name) for dataset_name in items]
    return _load(items)
def get_synsets_from_class_ids(class_ids):
    """Maps episode class ids to Synsets of the fine-grainedness split subgraph.

    ImageNet's ('ilsvrc_2012') DatasetSpecification is loaded from
    FLAGS.records_root_dir. Each id in `class_ids` is translated to a WordNet
    id via the spec's `class_names`. The matching Synset is then looked up
    among the Synsets of the subgraph for the split returned by
    `get_finegrainedness_split_enum()`.

    Args:
        class_ids: A np.array of ints containing the two class ids chosen for
            an episode. NOTE(review): the original docs describe these as
            ranging from 1 to the number of classes, but they are used
            directly as keys/indices of `class_names` -- confirm the base.

    Returns:
        A list of Synsets, in the same order as `class_ids`.

    Raises:
        ValueError: The dataset specification is not found in the expected
            location.
    """
    # DatasetSpecification of ImageNet.
    records_dir = os.path.join(FLAGS.records_root_dir, 'ilsvrc_2012')
    spec = dataset_spec.load_dataset_spec(records_dir)

    # Synset subgraph of the split used for the fine-grainedness analysis.
    subgraph = spec.split_subgraphs[get_finegrainedness_split_enum()]

    # class id -> WordNet id (e.g. n02075296).
    wordnet_ids = [spec.class_names[cid] for cid in class_ids]

    # WordNet id -> Synset, restricted to the split's subgraph.
    synset_by_id = imagenet_spec.get_synsets_from_ids(wordnet_ids, subgraph)
    return [synset_by_id[wid] for wid in wordnet_ids]
def _init_multi_source_dataset(self, items, split, episode_description):
    """Builds a multi-source episode pipeline and returns its next-element op.

    Args:
        items: Sequence of dataset names; each spec is loaded from
            `os.path.join(self.data_path, name)`.
        split: Split passed through to the pipeline.
        episode_description: Passed through as `episode_descr_config`.

    Returns:
        The result of `get_next()` on a one-shot iterator over the pipeline.
    """
    dataset_specs = [
        dataset_spec_lib.load_dataset_spec(
            os.path.join(self.data_path, dataset_name))
        for dataset_name in items
    ]
    # Enable ontology aware sampling for Omniglot (bilevel) and ImageNet (DAG).
    # Flags are computed per entry so every occurrence of those datasets gets
    # its flag; items.index() only ever flagged the first occurrence.
    use_bilevel_ontology_list = [name == 'omniglot' for name in items]
    use_dag_ontology_list = [name == 'ilsvrc_2012' for name in items]
    multi_source_pipeline = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_specs,
        use_dag_ontology_list=use_dag_ontology_list,
        use_bilevel_ontology_list=use_bilevel_ontology_list,
        split=split,
        episode_descr_config=episode_description,
        image_size=84)
    iterator = multi_source_pipeline.make_one_shot_iterator()
    return iterator.get_next()
def _init_single_source_dataset(self, dataset_name, split, episode_description):
    """Builds a single-source episode pipeline and returns its next-element op.

    Ontology-aware sampling is chosen by substring match on `dataset_name`:
    bilevel sampling when it contains 'omniglot', DAG sampling when it
    contains 'ilsvrc_2012'.

    Args:
        dataset_name: Dataset directory name under `self.data_path`.
        split: Split passed through to the pipeline.
        episode_description: Passed through as `episode_descr_config`.

    Returns:
        The result of `get_next()` on a one-shot iterator over the pipeline.
    """
    spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(self.data_path, dataset_name))
    episodes = pipeline.make_one_source_episode_pipeline(
        dataset_spec=spec,
        use_dag_ontology='ilsvrc_2012' in dataset_name,
        use_bilevel_ontology='omniglot' in dataset_name,
        split=split,
        episode_descr_config=episode_description,
        image_size=84,
        shuffle_buffer_size=1000)
    return episodes.make_one_shot_iterator().get_next()
def __init__(self, datasets, split, fixed_ways=None, fixed_support=None,
             fixed_query=None, use_ontology=True):
    """Builds a multi-source Meta-Dataset episode pipeline.

    Args:
        datasets: Non-empty list or tuple of dataset names (str), each loaded
            from `os.path.join(BASE_PATH, name)`; e.g. 'aircraft', 'cu_birds',
            'dtd', 'fungi', 'ilsvrc_2012', 'omniglot', 'quickdraw',
            'vgg_flower'.
        split: One of 'train', 'valid', 'test'.
        fixed_ways: Optional way count. With `use_ontology` it is applied as
            both `min_ways` and `max_ways_upper_bound` (with `num_ways=None`);
            otherwise it is passed as `num_ways`.
        fixed_support: Passed as `num_support` (None defers to gin config).
        fixed_query: Passed as `num_query` (None defers to gin config).
        use_ontology: If true, enables bilevel sampling for 'omniglot' and
            DAG sampling for 'ilsvrc_2012'.

    Sets `self.datasets`, `self.episode_config`, `self.episode_pipeline` and
    `self.n_total_classes`.
    """
    assert split in ['train', 'valid', 'test']
    assert isinstance(datasets, (list, tuple))
    # Check every entry (not only the first); an empty sequence is rejected
    # here rather than raising IndexError.
    assert datasets and all(isinstance(d, str) for d in datasets)
    print(f'Loading MetaDataset for {split}..')
    split = getattr(learning_spec.Split, split.upper())

    self.datasets = datasets

    # Ontology setting: one flag per dataset, in the same order.
    ontology_on = bool(use_ontology)
    use_bilevel_ontology_list = [
        ontology_on and d == 'omniglot' for d in self.datasets]
    use_dag_ontology_list = [
        ontology_on and d == 'ilsvrc_2012' for d in self.datasets]

    all_dataset_specs = [
        dataset_spec_lib.load_dataset_spec(os.path.join(BASE_PATH, name))
        for name in self.datasets
    ]

    if fixed_ways and use_ontology:
        # With ontology sampling the way count is pinned through
        # min_ways / max_ways_upper_bound and num_ways is left unset
        # (presumably because the ontology samplers ignore num_ways --
        # TODO confirm).
        self.episode_config = config.EpisodeDescriptionConfig(
            num_query=fixed_query,
            num_support=fixed_support,
            min_ways=fixed_ways,
            max_ways_upper_bound=fixed_ways,
            num_ways=None)
    else:
        # Episode description config (if num is None, use gin configuration)
        self.episode_config = config.EpisodeDescriptionConfig(
            num_query=fixed_query,
            num_support=fixed_support,
            num_ways=fixed_ways)

    # Episode pipeline
    self.episode_pipeline, self.n_total_classes = (
        pipeline.make_multisource_episode_pipeline2(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=self.episode_config,
            split=split,
            image_size=84))
    print('MetaDataset loaded: ', ', '.join(self.datasets))
def test_meta_dataset(self):
    """Checks that api.meta_dataset matches the original multi-source pipeline.

    Both pipelines are seeded identically (sampling.RNG and
    source_sampling_seed). The first `num_episodes` episodes are compared
    tensor by tensor, together with their source ids.
    """
    gin.parse_config_file(
        tfds.core.as_path(meta_dataset.__file__).parent /
        'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources,
        md_version='v1',
        meta_split=meta_split,
        source_sampling_seed=seed + 1,
        data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(
            os.path.join(FLAGS.meta_dataset_path, md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined,
    # but the call to `gin.parse_config_file` above changes the default value
    # to True, which is why we have to explicitly bind a new default value
    # here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1)
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    # zip() below silently truncates to the shorter input; pin both episode
    # counts so a short (or empty) pipeline cannot make the test pass
    # vacuously.
    np.testing.assert_equal(len(tfds_episodes), num_episodes)
    np.testing.assert_equal(len(md_episodes), num_episodes)

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
            tfds_episodes, md_episodes):
        np.testing.assert_equal(tfds_source_id, md_source_id)
        for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
            np.testing.assert_allclose(tfds_tensor, md_tensor)
def _init_dataset(self, dataset, split, episode_description):
    """Builds a single-source episode pipeline without ontology sampling.

    Args:
        dataset: Dataset directory name under `self.data_path`.
        split: Split passed through to the pipeline.
        episode_description: Passed through as `episode_descr_config`.

    Returns:
        The result of `get_next()` on a one-shot iterator over the pipeline.
    """
    records_dir = os.path.join(self.data_path, dataset)
    episodes = pipeline.make_one_source_episode_pipeline(
        dataset_spec=dataset_spec_lib.load_dataset_spec(records_dir),
        use_dag_ontology=False,
        use_bilevel_ontology=False,
        split=split,
        episode_descr_config=episode_description,
        image_size=84)
    return episodes.make_one_shot_iterator().get_next()
def main(unused_argv):
    """Dumps `FLAGS.num_episodes` sampled episodes of one dataset to disk.

    After parsing the gin config and bindings, episodes are sampled from
    FLAGS.dataset_name on FLAGS.split. Each episode's train and test halves are
    written with utils.dump_as_tfrecord. A JSON file at
    utils.get_info_path(FLAGS.output_dir) maps every written file's basename
    to utils.get_label_counts of its labels.
    """
    output_dir = FLAGS.output_dir
    logging.info(output_dir)
    tf.io.gfile.makedirs(output_dir)
    gin.parse_config_files_and_bindings(
        FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)

    name = FLAGS.dataset_name
    spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()

    # DAG sampling is for ImageNet variants and bilevel sampling for Omniglot;
    # either can be disabled with its ignore_* flag.
    use_dag_ontology = False
    if name in ('ilsvrc_2012', 'ilsvrc_2012_v2'):
        use_dag_ontology = not FLAGS.ignore_dag_ontology
    use_bilevel_ontology = False
    if name == 'omniglot':
        use_bilevel_ontology = not FLAGS.ignore_bilevel_ontology

    data_pipeline = pipeline.make_one_source_episode_pipeline(
        spec,
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=use_bilevel_ontology,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and
        # check reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)

    label_counts_by_file = {}
    episodes = data_pipeline.take(FLAGS.num_episodes)
    # Only one dataset is loaded, so the per-episode source id is discarded.
    for idx, (episode, _) in enumerate(episodes):
        logging.info('Dumping episode %d', idx)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        train_path = utils.get_file_path(output_dir, idx, 'train')
        test_path = utils.get_file_path(output_dir, idx, 'test')
        utils.dump_as_tfrecord(train_path, train_imgs, train_labels)
        utils.dump_as_tfrecord(test_path, test_imgs, test_labels)
        label_counts_by_file[os.path.basename(train_path)] = (
            utils.get_label_counts(train_labels))
        label_counts_by_file[os.path.basename(test_path)] = (
            utils.get_label_counts(test_labels))

    with tf.io.gfile.GFile(utils.get_info_path(output_dir), 'w') as f:
        f.write(json.dumps(label_counts_by_file, indent=2))
def make_md(lst, method, split='train', image_size=126, **kwargs):
    """Builds a multi-source Meta-Dataset pipeline over the datasets in `lst`.

    Args:
        lst: Sequence of dataset names, each loaded from BASE_PATH/<name>.
        method: 'episodic' for an episode pipeline (ontology sampling enabled
            for 'ilsvrc_2012' and 'omniglot', variable ways/shots) or 'batch'
            for a batch pipeline.
        split: 'train', 'val' (or its alias 'valid') or 'test'.
        image_size: Image size passed to the pipeline.
        **kwargs: For method='batch', must contain 'batch_size'.

    Returns:
        The pipeline returned by pipeline.make_multisource_episode_pipeline
        ('episodic') or pipeline.make_multisource_batch_pipeline ('batch').

    Raises:
        ValueError: If `split` or `method` is not recognised. Previously an
            unknown split raised UnboundLocalError and an unknown method
            silently returned None.
        KeyError: If method='batch' and 'batch_size' is missing from kwargs.
    """
    # Validate arguments before doing any I/O.
    split_attrs = {'train': 'TRAIN', 'val': 'VALID', 'valid': 'VALID',
                   'test': 'TEST'}
    if split not in split_attrs:
        raise ValueError(
            f'Unknown split {split!r}; expected one of {sorted(split_attrs)}.')
    if method not in ('episodic', 'batch'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'episodic' or 'batch'.")
    split_enum = getattr(learning_spec.Split, split_attrs[split])

    all_dataset_specs = [
        dataset_spec_lib.load_dataset_spec(os.path.join(BASE_PATH, name))
        for name in lst
    ]

    if method == 'episodic':
        # Enable ontology aware sampling for Omniglot and ImageNet.
        variable_ways_shots = config.EpisodeDescriptionConfig(
            num_query=None, num_support=None, num_ways=None)
        return pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=[name == 'ilsvrc_2012' for name in lst],
            use_bilevel_ontology_list=[name == 'omniglot' for name in lst],
            episode_descr_config=variable_ways_shots,
            split=split_enum,
            image_size=image_size)

    return pipeline.make_multisource_batch_pipeline(
        dataset_spec_list=all_dataset_specs,
        batch_size=kwargs['batch_size'],
        split=split_enum,
        image_size=image_size,
        add_dataset_offset=False)
def main(argv):
    """Validates a dataset's tfrecords directory against its dataset_spec.

    Runs four checks and logs every failure at error level:
      1. dataset_spec.name matches the records directory name;
      2. the set of files containing 'tfrecords' matches the names produced
         by dataset_spec.file_pattern for every class id;
      3. each class file holds the expected number of examples;
      4. every example in a class file carries that class id in
         FLAGS.label_field_name.

    Args:
        argv: Command-line arguments; only the program name is accepted.

    Raises:
        app.UsageError: If extra positional arguments are given.
        ValueError: If any check failed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Report failed checks as they occur and maintain a counter, instead of
    # raising exceptions right away, so all issues can be reported at once.
    num_failed_checks = 0

    # Load dataset_spec, this should fail if it is absent or incorrect.
    if FLAGS.dataset_spec_file is None:
        dataset_spec = dataset_spec_lib.load_dataset_spec(
            FLAGS.dataset_records_path)
    else:
        with tf.io.gfile.GFile(FLAGS.dataset_spec_file, 'r') as f:
            dataset_spec = json.load(
                f, object_hook=dataset_spec_lib.as_dataset_spec)
        dataset_spec.initialize()

    # 1. Check dataset name
    dir_name = os.path.basename(os.path.abspath(FLAGS.dataset_records_path))
    if dataset_spec.name != dir_name:
        num_failed_checks += 1
        logging.error(
            'The dataset name in "dataset_spec.json" (%s) does not match '
            'the name of the directory containing it (%s)',
            dataset_spec.name, dir_name)

    # 2. Check name and number of .tfrecords files
    num_classes = len(dataset_spec.class_names)
    try:
        expected_filenames = [
            dataset_spec.file_pattern.format(class_id)
            for class_id in range(num_classes)
        ]
    except (IndexError, KeyError, ValueError):
        # str.format raises IndexError for extra positional fields ('{1}'),
        # KeyError for named fields ('{name}') and ValueError for malformed
        # braces; all mean the pattern does not take the class id alone.
        num_failed_checks += 1
        err_msg = (
            'The `file_pattern` (%s) did not accept the class number as its '
            'only formatting argument. Using the default (%s).')
        default_pattern = '{}.tfrecords'
        logging.error(err_msg, dataset_spec.file_pattern, default_pattern)
        expected_filenames = [
            default_pattern.format(class_id) for class_id in range(num_classes)
        ]

    all_filenames = tf.io.gfile.listdir(FLAGS.dataset_records_path)
    # Heuristic to exclude obviously-not-tfrecords files.
    tfrecords_filenames = [
        f for f in all_filenames if 'tfrecords' in f.lower()
    ]
    expected_set = set(expected_filenames)
    present_set = set(tfrecords_filenames)
    if expected_set != present_set:
        num_failed_checks += 1
        logging.error(
            'The tfrecords files in %s do not match the dataset_spec.\n'
            'Unexpected files present:\n'
            '%s\n'
            'Expected files not present:\n'
            '%s', FLAGS.dataset_records_path,
            sorted(present_set - expected_set),
            sorted(expected_set - present_set))

    # Iterate through each dataset, count examples and check set of targets.
    # List of (class_id, expected_count, actual_count) triples.
    bad_counts = []
    # List of (filename, class_id, labels).
    bad_labels = []
    for class_id, filename in enumerate(expected_filenames):
        expected_count = dataset_spec.get_total_images_per_class(class_id)
        # Set membership: O(1) per class instead of scanning the list.
        if filename not in present_set:
            # The tfrecords does not exist, we use a negative count to
            # denote it.
            bad_counts.append((class_id, expected_count, -1))
            bad_labels.append((filename, class_id, set()))
            continue
        full_filepath = os.path.join(FLAGS.dataset_records_path, filename)
        try:
            count, labels = get_count_and_labels(full_filepath,
                                                 FLAGS.label_field_name)
        except tf.errors.InvalidArgumentError:
            logging.exception(
                'Unable to find label (%s) in the tf.Examples of file %s. '
                'Maybe try a different --label_field_name.',
                FLAGS.label_field_name, filename)
            # Fall back to counting examples only.
            count = count_records(full_filepath)
            labels = set()
        if count != expected_count:
            bad_counts.append((class_id, expected_count, count))
        if labels != {class_id}:
            # labels could include class_id among other, incorrect labels.
            bad_labels.append((filename, class_id, labels))

    # 3. Check number of examples
    if bad_counts:
        num_failed_checks += 1
        # The legend matches the (class_id, expected, actual) triples stored
        # in bad_counts (it previously said "filename").
        logging.error(
            'The number of tfrecords in the following files do not match '
            'the expected number of examples in that class.\n'
            '(class_id, expected, actual) # -1 denotes a missing file.\n'
            '%s', bad_counts)

    # 4. Check the targets stored in the tfrecords files.
    if bad_labels:
        num_failed_checks += 1
        logging.error(
            'The labels stored inside the tfrecords (in field %s) do not '
            'all match the expected value (class_id).\n'
            '(filename, class_id, values)\n'
            '%s', FLAGS.label_field_name, bad_labels)

    # Report results
    if num_failed_checks:
        raise ValueError('%d checks failed. See the error-level logs.' %
                         num_failed_checks)
def __init__(self,
             path_to_records,
             batch_config=None,
             episode_config=None,
             valid_episode_config=None,
             pool='train',
             mode='episode'):
    """Sets up the data pipelines for the requested pool and mode.

    Args:
        path_to_records: Absolute path of the tfrecords from which data will
            be generated. Should have the form ~/meta_train or ~/meta_test.
        batch_config: Array-like. In batch mode, controls the size of batches
            and decoded images generated. Ignored otherwise.
            [image_size, batch_size]
        episode_config: Array-like. Describes the episode configuration. If
            pool='train' and mode='episode', it sets the meta-train episodes
            configuration generator. If pool='test', it sets the meta-test
            episodes configuration generator.
            [image_size, num_ways, num_examples_per_class_in_support,
             num_total_query_examples]
            None (the default) means [28, 5, 1, 19].
        valid_episode_config: Array-like. Sets the episode configuration for
            the meta-valid episodes generator. It is only relevant in the
            pool='train' and mode='episode' setting, since both meta-train and
            meta-valid split could have their own episodes configuration.
            None (the default) means [28, 5, 1, 19].
        pool: The split from which images are taken. Only 'train' and 'test'
            pool are allowed. Automatically create a meta-validation generator
            if pool='train'.
        mode: The configuration of the data coming from the meta-train
            generator. 'batch' and 'episode' are available.

    Raises:
        ValueError: If `pool` or `mode` is invalid, if pool='test' is combined
            with mode='batch', or if the config required by `mode` cannot be
            indexed as documented above.
    """
    # Build fresh default lists on every call; list defaults in the signature
    # would be one shared object across all instances.
    if episode_config is None:
        episode_config = [28, 5, 1, 19]
    if valid_episode_config is None:
        valid_episode_config = [28, 5, 1, 19]

    self.episode_config = episode_config
    self.valid_episode_config = valid_episode_config
    self.pool = pool
    self.mode = mode

    if self.pool not in ['train', 'test']:
        raise ValueError(
            ('In DataGenerator, only \'train\' or \'test\' ' +
             'are valid arguments for pool. Received :{}').format(self.pool))
    if self.mode not in ['episode', 'batch']:
        raise ValueError(
            ('In DataGenerator, only \'episode\' or \'batch\' ' +
             'are valid arguments for mode. Received :{}').format(self.mode))
    if self.pool == 'test' and self.mode == 'batch':
        raise ValueError(
            ('In DataGenerator, batch mode is only available ' +
             'at meta-train time. Received pool : {} and mode : {}').format(
                 self.pool, self.mode))

    if self.mode == 'batch':
        try:
            self.image_size_batch = batch_config[0]
            self.batch_size = batch_config[1]
        except (TypeError, IndexError, KeyError) as err:
            # TypeError covers None / non-subscriptable values.
            raise ValueError(
                ('The batch_config argument in DataGenerator ' +
                 'is not defined properly. Make sure it has the form ' +
                 '[img_size, batch_size]. ' +
                 'Received batch_config : {}').format(batch_config)) from err

    if self.mode == 'episode':
        # Only checks that positions 0-3 can be indexed; the values
        # themselves are not validated here.
        for cfg in (self.episode_config, self.valid_episode_config):
            try:
                for i in range(4):
                    _ = cfg[i]
            except (TypeError, IndexError, KeyError) as err:
                raise ValueError(
                    ('The episode config argument in DataGenerator ' +
                     'is not defined properly. Make sure it has the form ' +
                     '[img_size, num_ways, num_shots, num_query]. ' +
                     'Received episode_config : {}').format(cfg)) from err

    self.dataset_spec = dataset_spec_lib.load_dataset_spec(path_to_records)

    # Loading root path.
    root_path = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
    gin_path = os.path.join(root_path, 'metadl/gin/default/decoders.gin')
    gin.parse_config_file(gin_path)

    self.meta_train_pipeline = None
    self.meta_test_pipeline = None
    self.meta_valid_pipeline = None

    logging.info('Creating %s generator for meta-%s dataset.',
                 self.mode, self.pool)
    self.set_fixed_episode_config()
    if self.pool == 'train':
        if self.mode == 'episode':
            self.generate_meta_train_episodes_pipeline()
        else:
            self.generate_meta_train_batch_pipeline()
    else:
        self.generate_meta_test_episodes_pipeline()
def test_full_ways_dataset(self, source, md_source, md_version, meta_split,
                           remap_labels):
    """Checks full_ways_dataset labels against the meta-split's classes.

    Every batch must contain only labels from `meta_split`'s classes and none
    from the other splits. Full-size batches should also be reasonably
    label-diverse: at least 75% of them need a label entropy above half of
    log(number of allowed classes).
    """
    spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))

    allowed_classes = set()
    forbidden_classes = set()
    for ms in ('train', 'valid', 'test'):
        split_classes = spec.get_classes(
            getattr(learning_spec.Split, ms.upper()))
        if ms == meta_split:
            allowed_classes.update(split_classes)
        else:
            forbidden_classes.update(split_classes)

    # MD-v2's ilsvrc_2012 train classes are sorted by class name, while the
    # TFDS implementation keeps the v1 order, so TFDS labels must be remapped.
    if remap_labels:
        # argsort once: position in sorted order -> index in class_names;
        # argsort twice: index in class_names -> position in sorted order.
        class_names = tfds.builder(
            'meta_dataset/ilsvrc_2012', data_dir=FLAGS.tfds_path
        ).info.metadata['class_names']
        label_map = np.argsort(np.argsort(class_names))
        label_remap_fn = np.vectorize(lambda x: label_map[x])

    total_images = sum(
        spec.get_total_images_per_class(c) for c in allowed_classes)
    shuffle_buffer_size = min(10_000, total_images)
    batch_size = min(1024, total_images)
    dataset = (
        api.full_ways_dataset(
            md_source=source,
            md_version=md_version,
            meta_split=meta_split,
            decoders={'image': tfds.decode.SkipDecoding()},
            data_dir=FLAGS.tfds_path,
        )
        .shuffle(shuffle_buffer_size)
        .batch(batch_size, drop_remainder=False)
        .prefetch(1)
    )

    entropies = []
    for batch in dataset.as_numpy_iterator():
        if remap_labels:
            labels = label_remap_fn(batch['label'])
        else:
            labels = batch['label']
        # Only full batches are scored; the remainder batch is too small to
        # be expected to be near-uniform.
        if len(labels) == batch_size:
            counts = collections.Counter(labels)
            probs = np.array(
                [counts[c] for c in allowed_classes], dtype='float32'
            ) / batch_size
            entropies.append(-np.nan_to_num(probs * np.log(probs)).sum())
        seen = set(labels)
        self.assertContainsSubset(seen, allowed_classes)
        self.assertNoCommonElements(seen, forbidden_classes)

    # Pass if at least 75% of the full batches have a label entropy of at
    # least 50% of the maximum (uniform-distribution) entropy.
    high_entropy = np.array(entropies) > 0.5 * np.log(len(allowed_classes))
    self.assertGreater(high_entropy.mean(), 0.75)
def test_episode_dataset_matches_meta_dataset(self, source, md_source,
                                              md_version, meta_split,
                                              remap_labels):
    """Checks that api.episode_dataset matches the original single-source pipeline.

    Both pipelines are seeded identically through sampling.RNG. The first
    `num_episodes` episodes are compared tensor by tensor, together with
    their source ids. When `remap_labels` is set, TFDS class ids are first
    mapped to MD-v2's name-sorted order.
    """
    gin.parse_config_file(
        tfds.core.as_path(meta_dataset.__file__).parent /
        'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10

    builder = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path)
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.episode_dataset(
        builder, md_version, meta_split, source_id=0)
    # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
    # whereas the TFDS implementation intentionally keeps the v1 class order.
    if remap_labels:
        # The first argsort tells us where to look in class_names for
        # position i in the sorted class list. The second argsort reverses
        # that: it tells us, for position j in class_names, where to place
        # that class in the sorted class list.
        label_map = np.argsort(
            np.argsort(builder.info.metadata['class_names']))
        label_remap_fn = np.vectorize(lambda x: label_map[x])
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))
    sampling.RNG = np.random.RandomState(seed)
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
        use_bilevel_ontology=dataset_spec.name == 'omniglot',
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height)
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    # zip() below silently truncates to the shorter input; pin both episode
    # counts so a short (or empty) pipeline cannot make the test pass
    # vacuously.
    np.testing.assert_equal(len(tfds_episodes), num_episodes)
    np.testing.assert_equal(len(md_episodes), num_episodes)

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
            tfds_episodes, md_episodes):
        np.testing.assert_equal(tfds_source_id, md_source_id)

        if remap_labels:
            # Entries 2 and 5 of the episode tuple are remapped (presumably
            # the support/query class ids -- TODO confirm tuple layout).
            tfds_episode = list(tfds_episode)
            tfds_episode[2] = label_remap_fn(tfds_episode[2])
            tfds_episode[5] = label_remap_fn(tfds_episode[5])
            tfds_episode = tuple(tfds_episode)

        for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
            np.testing.assert_allclose(tfds_tensor, md_tensor)
break yield idx, episode # 5 SPLIT = learning_spec.Split.TRAIN ### Reading datasets #ALL_DATASETS = ['aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012', # 'omniglot', 'quickdraw', 'vgg_flower'] ALL_DATASETS = ['cu_birds'] all_dataset_specs = [] for dataset_name in ALL_DATASETS: dataset_records_path = os.path.join(BASE_PATH, dataset_name) dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path) all_dataset_specs.append(dataset_spec) ## (1) Episodic Mode use_bilevel_ontology_list = [False] * len(ALL_DATASETS) use_dag_ontology_list = [False] * len(ALL_DATASETS) # Enable ontology aware sampling for Omniglot and ImageNet. #use_bilevel_ontology_list[5] = True #use_dag_ontology_list[4] = True variable_ways_shots = config.EpisodeDescriptionConfig(num_query=None, num_support=None, num_ways=None) dataset_episodic = pipeline.make_multisource_episode_pipeline( dataset_spec_list=all_dataset_specs,