def generate_meta_train_episodes_pipeline(self):
  """Creates episode generators for the meta-train and meta-valid splits.

  ----------------------------------------------------------------------
  Details of some arguments inside the function:

  The following arguments are ignored since we are using a fixed episode
  description:
  - max_ways_upper_bound, max_num_query, max_support_{}, min/max_log_weight
  dag_ontology: only relevant for the ImageNet dataset. Ignored.
  bilevel_ontology: whether to use Omniglot's bilevel (alphabet/character)
    ontology when sampling classes from it. Ignored.
  min_examples: 0 means that no class is discarded for having too few
    examples.
  """
  self.meta_train_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=self.dataset_spec,
      use_dag_ontology=False,
      use_bilevel_ontology=False,
      split=learning_spec.Split.TRAIN,
      episode_descr_config=self.fixed_ways_shots,
      pool=None,
      shuffle_buffer_size=None,
      read_buffer_size_bytes=None,
      num_prefetch=0,
      image_size=self.episode_config[0],
      num_to_take=None)

  self.meta_valid_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=self.dataset_spec,
      use_dag_ontology=False,
      use_bilevel_ontology=False,
      split=learning_spec.Split.VALID,
      episode_descr_config=self.fixed_ways_shots_valid,
      pool=None,
      shuffle_buffer_size=None,
      read_buffer_size_bytes=None,
      num_prefetch=0,
      image_size=self.valid_episode_config[0],
      num_to_take=None)

  logging.info('Meta-valid episode config : {}'.format(
      self.valid_episode_config))
  logging.info('Meta-train episode config : {}'.format(
      self.episode_config))
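# --- Usage sketch (not part of the original code) ---
# A minimal consumption loop for the pipelines built above, assuming TF2
# eager execution and that, as elsewhere in Meta-Dataset, each pipeline
# yields (episode, source_id) pairs where an episode is a 6-tuple of
# tensors. `data_generator` is a hypothetical instance of the class that
# defines generate_meta_train_episodes_pipeline.
for episode, source_id in data_generator.meta_train_pipeline.take(2):
  (support_images, support_labels, support_class_ids,
   query_images, query_labels, query_class_ids) = episode
  print(support_images.shape, query_images.shape)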
def generate_meta_test_episodes_pipeline(self):
  """Creates the episode generator for the meta-test dataset.

  Notice that at meta-test time, the meta-learning algorithms always
  receive data in the form of episodes. Also, participants can't control
  these episodes' settings.

  ----------------------------------------------------------------------
  Details of some arguments inside the function:

  The following arguments are ignored since we are using a fixed episode
  description:
  - max_ways_upper_bound, max_num_query, max_support_{}, min/max_log_weight
  dag_ontology: only relevant for the ImageNet dataset. Ignored.
  bilevel_ontology: whether to use Omniglot's bilevel (alphabet/character)
    ontology when sampling classes from it. Ignored.
  min_examples: 0 means that no class is discarded for having too few
    examples.
  """
  self.meta_test_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=self.dataset_spec,
      use_dag_ontology=False,
      use_bilevel_ontology=False,
      split=learning_spec.Split.TEST,
      episode_descr_config=self.fixed_ways_shots,
      pool=None,
      shuffle_buffer_size=3000,
      read_buffer_size_bytes=None,
      num_prefetch=0,
      image_size=self.episode_config[0],
      num_to_take=None)

  logging.info('Meta-test episode config : {}'.format(
      self.episode_config))
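# --- Configuration sketch (not part of the original code) ---
# One hedged way to build the fixed episode description (`fixed_ways_shots`)
# used above: bind the ways/shots through gin, then instantiate
# EpisodeDescriptionConfig, assuming its remaining parameters come from a
# parsed gin config file (as in the dump script further below). The 5-way
# 1-shot values are illustrative only.
gin.bind_parameter('EpisodeDescriptionConfig.num_ways', 5)
gin.bind_parameter('EpisodeDescriptionConfig.num_support', 1)
gin.bind_parameter('EpisodeDescriptionConfig.num_query', 19)
fixed_ways_shots = config.EpisodeDescriptionConfig()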
def _init_single_source_dataset(self, dataset_name, split,
                                episode_description):
  dataset_spec = self._get_dataset_spec(dataset_name)
  self.specs_dict[split] = dataset_spec

  # Enable ontology-aware sampling for Omniglot and ImageNet.
  use_bilevel_ontology = False
  if 'omniglot' in dataset_name:
    use_bilevel_ontology = True
  use_dag_ontology = False
  if 'ilsvrc_2012' in dataset_name:
    use_dag_ontology = True

  single_source_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=dataset_spec,
      use_dag_ontology=use_dag_ontology,
      use_bilevel_ontology=use_bilevel_ontology,
      split=split,
      episode_descr_config=episode_description,
      image_size=84,
      shuffle_buffer_size=1000)

  iterator = single_source_pipeline.make_one_shot_iterator()
  return iterator.get_next()
def generate_meta_train_batch_pipeline(self):
  """Creates a batch generator for examples from the meta-train split.

  Also creates an episode generator for examples coming from the meta-valid
  split. Indeed, the way data comes from the meta-valid split should match
  its meta-test counterpart. The meta-valid data generator falls back to
  episode_config for its episode description if its own configuration is
  not provided.

  ----------------------------------------------------------------------
  Details of some arguments inside the function:

  The following arguments are ignored since we are using a fixed episode
  description:
  - max_ways_upper_bound, max_num_query, max_support_{}, min/max_log_weight
  dag_ontology: only relevant for the ImageNet dataset. Ignored.
  bilevel_ontology: whether to use Omniglot's bilevel (alphabet/character)
    ontology when sampling classes from it. Ignored.
  min_examples: 0 means that no class is discarded for having too few
    examples.
  """
  self.meta_train_pipeline = pipeline.make_one_source_batch_pipeline(
      dataset_spec=self.dataset_spec,
      split=learning_spec.Split.TRAIN,
      batch_size=self.batch_size,
      pool=None,
      shuffle_buffer_size=None,
      read_buffer_size_bytes=None,
      num_prefetch=0,
      image_size=self.image_size_batch,
      num_to_take=None)

  self.meta_valid_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=self.dataset_spec,
      use_dag_ontology=False,
      use_bilevel_ontology=False,
      split=learning_spec.Split.VALID,
      episode_descr_config=self.fixed_ways_shots_valid,
      pool=None,
      shuffle_buffer_size=None,
      read_buffer_size_bytes=None,
      num_prefetch=0,
      image_size=self.valid_episode_config[0],
      num_to_take=None)

  logging.info('Meta-valid episode config : {}'.format(
      self.valid_episode_config))
  logging.info('Meta-train batch config : [{}, {}]'.format(
      self.batch_size, self.image_size_batch))
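# --- Usage sketch (not part of the original code) ---
# Consuming the batch pipeline above, assuming TF2 eager execution and that
# batch pipelines, like episode pipelines, yield (data, source_id) pairs,
# here with `data` an (images, labels) tuple. `data_generator` is a
# hypothetical instance of the enclosing class.
for (images, labels), source_id in data_generator.meta_train_pipeline.take(1):
  print(images.shape, labels.shape)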
def _init_dataset(self, dataset, split, episode_description):
  dataset_records_path = os.path.join(self.data_path, dataset)
  dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
  single_source_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=dataset_spec,
      use_dag_ontology=False,
      use_bilevel_ontology=False,
      split=split,
      episode_descr_config=episode_description,
      image_size=84)
  iterator = single_source_pipeline.make_one_shot_iterator()
  return iterator.get_next()
def main(unused_argv):
  logging.info(FLAGS.output_dir)
  tf.io.gfile.makedirs(FLAGS.output_dir)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings,
                                      finalize_config=True)
  dataset_spec = dataset_spec_lib.load_dataset_spec(
      os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
  data_config = config.DataConfig()
  episode_descr_config = config.EpisodeDescriptionConfig()
  use_dag_ontology = (FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                      and not FLAGS.ignore_dag_ontology)
  use_bilevel_ontology = (FLAGS.dataset_name == 'omniglot'
                          and not FLAGS.ignore_bilevel_ontology)
  data_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec,
      use_dag_ontology=use_dag_ontology,
      use_bilevel_ontology=use_bilevel_ontology,
      split=FLAGS.split,
      episode_descr_config=episode_descr_config,
      # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
      # reproducibility of dumping.
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      read_buffer_size_bytes=data_config.read_buffer_size_bytes,
      num_prefetch=data_config.num_prefetch)
  dataset = data_pipeline.take(FLAGS.num_episodes)

  images_per_class_dict = {}
  # Ignoring dataset number since we are loading one dataset.
  for episode_number, (episode, _) in enumerate(dataset):
    logging.info('Dumping episode %d', episode_number)
    train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
    path_train = utils.get_file_path(FLAGS.output_dir, episode_number, 'train')
    path_test = utils.get_file_path(FLAGS.output_dir, episode_number, 'test')
    utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
    utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
    images_per_class_dict[os.path.basename(path_train)] = (
        utils.get_label_counts(train_labels))
    images_per_class_dict[os.path.basename(path_test)] = (
        utils.get_label_counts(test_labels))

  info_path = utils.get_info_path(FLAGS.output_dir)
  with tf.io.gfile.GFile(info_path, 'w') as f:
    f.write(json.dumps(images_per_class_dict, indent=2))
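# --- Usage sketch (not part of the original code) ---
# Reading back the per-file label counts that main() writes above; only the
# json and tf.io.gfile calls already used in the script are assumed.
with tf.io.gfile.GFile(utils.get_info_path(FLAGS.output_dir), 'r') as f:
  images_per_class = json.loads(f.read())
for filename, counts in images_per_class.items():
  print(filename, counts)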
def _init_single_source_dataset(self, dataset_name, split,
                                episode_description):
  dataset_records_path = os.path.join(self.data_path, dataset_name)
  dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)

  # Enable ontology-aware sampling for Omniglot and ImageNet.
  use_bilevel_ontology = False
  if 'omniglot' in dataset_name:
    use_bilevel_ontology = True
  use_dag_ontology = False
  if 'ilsvrc_2012' in dataset_name:
    use_dag_ontology = True

  single_source_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec=dataset_spec,
      use_dag_ontology=use_dag_ontology,
      use_bilevel_ontology=use_bilevel_ontology,
      split=split,
      episode_descr_config=episode_description,
      image_size=84)
  iterator = single_source_pipeline.make_one_shot_iterator()
  return iterator.get_next()
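# --- Usage sketch (not part of the original code) ---
# Building one "next episode" op per split in TF1 graph mode, matching the
# make_one_shot_iterator call above. `loader` and `episode_descr_config`
# are hypothetical stand-ins for an instance of the enclosing class and a
# configured EpisodeDescriptionConfig.
next_episode_ops = {
    split: loader._init_single_source_dataset('omniglot', split,
                                              episode_descr_config)
    for split in (learning_spec.Split.TRAIN,
                  learning_spec.Split.VALID,
                  learning_spec.Split.TEST)
}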
def build_episode(self, split):
  """Builds an EpisodeDataset containing the next data for "split".

  Args:
    split: A string, either 'train', 'valid', or 'test'.

  Returns:
    An EpisodeDataset.
  """
  shuffle_buffer_size = self.data_config.shuffle_buffer_size
  read_buffer_size_bytes = self.data_config.read_buffer_size_bytes
  benchmark_spec = (self.valid_benchmark_spec if split == 'valid'
                    else self.benchmark_spec)
  (_, image_shape, dataset_spec_list, has_dag_ontology,
   has_bilevel_ontology) = benchmark_spec
  episode_spec = self.split_episode_or_batch_specs[split]
  (dataset_split, num_classes, num_train_examples,
   num_test_examples) = episode_spec
  # TODO(lamblinp): Support non-square shapes if necessary. For now, all
  # images are resized to square, even if it changes the aspect ratio.
  image_size = image_shape[0]
  if image_shape[1] != image_size:
    raise ValueError(
        'Expected a square image shape, not {}'.format(image_shape))

  # TODO(lamblinp): pass specs directly to the pipeline builder.
  # TODO(lamblinp): move the special case directly in make_..._pipeline
  if len(dataset_spec_list) == 1:
    use_dag_ontology = has_dag_ontology[0]
    if self.eval_finegrainedness or self.eval_imbalance_dataset:
      use_dag_ontology = False
    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec_list[0],
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=has_bilevel_ontology[0],
        split=dataset_split,
        num_ways=num_classes,
        num_support=num_train_examples,
        num_query=num_test_examples,
        shuffle_buffer_size=shuffle_buffer_size,
        read_buffer_size_bytes=read_buffer_size_bytes,
        image_size=image_size)
  else:
    data_pipeline = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list,
        use_dag_ontology_list=has_dag_ontology,
        use_bilevel_ontology_list=has_bilevel_ontology,
        split=dataset_split,
        num_ways=num_classes,
        num_support=num_train_examples,
        num_query=num_test_examples,
        shuffle_buffer_size=shuffle_buffer_size,
        read_buffer_size_bytes=read_buffer_size_bytes,
        image_size=image_size)

  data_pipeline = apply_dataset_options(data_pipeline)
  iterator = data_pipeline.make_one_shot_iterator()
  (support_images, support_labels, support_class_ids, query_images,
   query_labels, query_class_ids) = iterator.get_next()
  return providers.EpisodeDataset(
      train_images=support_images,
      test_images=query_images,
      train_labels=support_labels,
      test_labels=query_labels,
      train_class_ids=support_class_ids,
      test_class_ids=query_class_ids)
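# --- Usage sketch (not part of the original code) ---
# Materializing one episode from build_episode in TF1 graph mode, consistent
# with the make_one_shot_iterator call above. `learner` is a hypothetical
# instance of the enclosing class.
episode = learner.build_episode('train')
with tf.compat.v1.Session() as sess:
  support_images, support_labels = sess.run(
      [episode.train_images, episode.train_labels])
print(support_images.shape, support_labels.shape)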
def test_episode_dataset_matches_meta_dataset(self, source, md_source,
                                              md_version, meta_split,
                                              remap_labels):
  gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                        'learn/gin/setups/data_config_tfds.gin')
  gin.parse_config(_DETERMINISTIC_CONFIG)
  data_config = config_lib.DataConfig()
  seed = 20210917
  num_episodes = 10

  builder = tfds.builder('meta_dataset', config=source,
                         data_dir=FLAGS.tfds_path)
  sampling.RNG = np.random.RandomState(seed)
  tfds_episode_dataset = api.episode_dataset(
      builder, md_version, meta_split, source_id=0)

  # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
  # whereas the TFDS implementation intentionally keeps the v1 class order.
  if remap_labels:
    # The first argsort tells us where to look in class_names for position i
    # in the sorted class list. The second argsort reverses that: it tells
    # us, for position j in class_names, where to place that class in the
    # sorted class list.
    label_map = np.argsort(np.argsort(builder.info.metadata['class_names']))
    label_remap_fn = np.vectorize(lambda x: label_map[x])

  tfds_episodes = list(
      tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

  dataset_spec = dataset_spec_lib.load_dataset_spec(
      os.path.join(FLAGS.meta_dataset_path, md_source))
  sampling.RNG = np.random.RandomState(seed)
  # We should not skip TFExample decoding in the original Meta-Dataset
  # implementation.
  gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
  md_episode_dataset = pipeline.make_one_source_episode_pipeline(
      dataset_spec,
      use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
      use_bilevel_ontology=dataset_spec.name == 'omniglot',
      split=getattr(learning_spec.Split, meta_split.upper()),
      episode_descr_config=config_lib.EpisodeDescriptionConfig(),
      pool=None,
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      image_size=data_config.image_height)
  md_episodes = list(
      md_episode_dataset.take(num_episodes).as_numpy_iterator())

  for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
      tfds_episodes, md_episodes):
    np.testing.assert_equal(tfds_source_id, md_source_id)
    if remap_labels:
      tfds_episode = list(tfds_episode)
      tfds_episode[2] = label_remap_fn(tfds_episode[2])
      tfds_episode[5] = label_remap_fn(tfds_episode[5])
      tfds_episode = tuple(tfds_episode)
    for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
      np.testing.assert_allclose(tfds_tensor, md_tensor)
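# --- Worked example (not part of the original code) ---
# The double argsort used above, on a toy class list: label_map[i] is the
# position of class_names[i] in the alphabetically sorted class list.
class_names = ['cat', 'ant', 'bee']
first = np.argsort(class_names)  # [1, 2, 0]: indices that sort the names
label_map = np.argsort(first)    # [2, 0, 1]: 'cat'->2, 'ant'->0, 'bee'->1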