Example #1
    def generate_meta_train_episodes_pipeline(self):
        """ Creates an episode generator for both meta-train and meta-valid 
        splits.

        ----------------------------------------------------------------------
        Details of some arguments inside the function : 

        The following arguments are ignored since we are using fixed episode
        description : 
            - max_ways_upper_bound, max_num_query, max_support_{}, 
                min/max_log_weight

        dag_ontology : only relevant for ImageNet dataset. Ignored.
        bilevel_ontology : Whether to ignore Omniglot's DAG ontology when
            sampling classes from it. Ignored.
        min_examples : 0 means that no class is discarded from having a too 
            small number of examples.
        """
        self.meta_train_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=self.dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=learning_spec.Split.TRAIN,
            episode_descr_config=self.fixed_ways_shots,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.episode_config[0],
            num_to_take=None
        )
        self.meta_valid_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=self.dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=learning_spec.Split.VALID,
            episode_descr_config=self.fixed_ways_shots_valid,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.valid_episode_config[0],
            num_to_take=None
        )
        logging.info('Meta-Valid episode config : {}'.format(
            self.valid_episode_config))
        logging.info('Meta-Train episode config : {}'.format(
            self.episode_config))
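A minimal consumption sketch (not part of the original example, and assuming TF2 eager execution): each element yielded by these episode pipelines is an (episode, source_id) pair, where the episode itself is a 6-tuple of support/query images, labels, and class ids, the same structure Example #6 below unpacks.

    for episode, source_id in self.meta_train_pipeline.take(2):
        (support_images, support_labels, support_class_ids,
         query_images, query_labels, query_class_ids) = episode
        logging.info('Support images: %s, query images: %s',
                     support_images.shape, query_images.shape)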
Example #2
    def generate_meta_test_episodes_pipeline(self):
        """Creates the episode generator for the meta-test dataset. 
        
        Notice that at meta-test time, the meta-learning algorithms always
        receive data in the form of episodes. Also, participantscan't control 
        these episodes' setting.

        ----------------------------------------------------------------------
        Details of some arguments inside the function : 

        The following arguments are ignored since we are using fixed episode
        description : 
            - max_ways_upper_bound, max_num_query, max_support_{}, 
                min/max_log_weight

        dag_ontology : only relevant for ImageNet dataset. Ignored.
        bilevel_ontology : Whether to ignore Omniglot's DAG ontology when
            sampling classes from it. Ignored.
        min_examples : 0 means that no class is discarded from having a too 
            small number of examples.
        """
        self.meta_test_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=self.dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=learning_spec.Split.TEST,
            episode_descr_config=self.fixed_ways_shots,
            pool=None,
            shuffle_buffer_size=3000,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.episode_config[0],
            num_to_take=None)
        logging.info('Meta-test episode config : {}'.format(
            self.episode_config))
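A hedged evaluation sketch (NUM_TEST_EPISODES and evaluate_on_episode are placeholders, not part of the original code): since meta-test data always arrives as episodes, evaluation typically iterates over a fixed number of them.

    NUM_TEST_EPISODES = 600  # placeholder value
    for episode, _ in self.meta_test_pipeline.take(NUM_TEST_EPISODES):
        support_images, support_labels, _, query_images, query_labels, _ = episode
        # A hypothetical evaluate_on_episode(...) helper would adapt on the
        # support set and score on the query set; it is not defined here.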
Example #3
    def _init_single_source_dataset(self, dataset_name, split,
                                    episode_description):
        dataset_spec = self._get_dataset_spec(dataset_name)
        self.specs_dict[split] = dataset_spec

        # Enable ontology aware sampling for Omniglot and ImageNet.
        use_bilevel_ontology = False
        if 'omniglot' in dataset_name:
            use_bilevel_ontology = True

        use_dag_ontology = False
        if 'ilsvrc_2012' in dataset_name:
            use_dag_ontology = True

        single_source_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=dataset_spec,
            use_dag_ontology=use_dag_ontology,
            use_bilevel_ontology=use_bilevel_ontology,
            split=split,
            episode_descr_config=episode_description,
            image_size=84,
            shuffle_buffer_size=1000)

        iterator = single_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
Example #4
    def generate_meta_train_batch_pipeline(self):
        """ Creates a batch generator for examples coming from the meta-train 
        split. Also creates an episode generator for examples coming from the 
        meta-valid split. Indeed, the way data comes from the meta-valid split
        should match its meta-test split counter-part.

        The meta-valid data generator will use the episode_config for the 
        episodes description if its own configuration is not provided.
        ----------------------------------------------------------------------
        Details of some arguments inside the function : 

        The following arguments are ignored since we are using fixed episode
        description : 
            - max_ways_upper_bound, max_num_query, max_support_{}, 
                min/max_log_weight

        dag_ontology : only relevant for ImageNet dataset. Ignored.
        bilevel_ontology : Whether to ignore Omniglot's DAG ontology when
            sampling classes from it. Ignored.
        min_examples : 0 means that no class is discarded from having a too 
            small number of examples.
        """
        self.meta_train_pipeline = pipeline.make_one_source_batch_pipeline(
            dataset_spec=self.dataset_spec,
            split=learning_spec.Split.TRAIN,
            batch_size=self.batch_size,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.image_size_batch,
            num_to_take=None
        )
        self.meta_valid_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=self.dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=learning_spec.Split.VALID,
            episode_descr_config=self.fixed_ways_shots_valid,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.valid_episode_config[0],
            num_to_take=None
        )
        logging.info('Meta-valid episode config : {}'.format(
            self.valid_episode_config))
        logging.info('Meta-train batch config : [{}, {}]'.format(
            self.batch_size, self.image_size_batch))
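A quick structural check (assuming TF2, where a tf.data.Dataset exposes element_spec; this snippet is not part of the original code): the meta-train pipeline yields flat batches while the meta-valid pipeline yields full episodes.

    logging.info('Meta-train batch element spec: %s',
                 self.meta_train_pipeline.element_spec)
    logging.info('Meta-valid episode element spec: %s',
                 self.meta_valid_pipeline.element_spec)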
Example #5
    def _init_dataset(self, dataset, split, episode_description):
        dataset_records_path = os.path.join(self.data_path, dataset)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)

        single_source_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=split,
            episode_descr_config=episode_description,
            image_size=84)

        iterator = single_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
Example #6
def main(unused_argv):
    logging.info(FLAGS.output_dir)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    gin.parse_config_files_and_bindings(FLAGS.gin_config,
                                        FLAGS.gin_bindings,
                                        finalize_config=True)
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()
    use_dag_ontology = (FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                        and not FLAGS.ignore_dag_ontology)
    use_bilevel_ontology = (FLAGS.dataset_name == 'omniglot'
                            and not FLAGS.ignore_bilevel_ontology)
    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=use_bilevel_ontology,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
        # reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)
    dataset = data_pipeline.take(FLAGS.num_episodes)

    images_per_class_dict = {}
    # Ignoring dataset number since we are loading one dataset.
    for episode_number, (episode, _) in enumerate(dataset):
        logging.info('Dumping episode %d', episode_number)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        path_train = utils.get_file_path(FLAGS.output_dir, episode_number,
                                         'train')
        path_test = utils.get_file_path(FLAGS.output_dir, episode_number,
                                        'test')
        utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
        utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
        images_per_class_dict[os.path.basename(path_train)] = (
            utils.get_label_counts(train_labels))
        images_per_class_dict[os.path.basename(path_test)] = (
            utils.get_label_counts(test_labels))
    info_path = utils.get_info_path(FLAGS.output_dir)
    with tf.io.gfile.GFile(info_path, 'w') as f:
        f.write(json.dumps(images_per_class_dict, indent=2))
Example #7
    def _init_single_source_dataset(self, dataset_name, split, episode_description):
        dataset_records_path = os.path.join(self.data_path, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)

        # Enable ontology aware sampling for Omniglot and ImageNet.
        use_bilevel_ontology = False
        if 'omniglot' in dataset_name:
            use_bilevel_ontology = True

        use_dag_ontology = False
        if 'ilsvrc_2012' in dataset_name:
            use_dag_ontology = True

        single_source_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=dataset_spec,
            use_dag_ontology=use_dag_ontology,
            use_bilevel_ontology=use_bilevel_ontology,
            split=split,
            episode_descr_config=episode_description,
            image_size=84)

        iterator = single_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
Example #8
    def build_episode(self, split):
        """Builds an EpisodeDataset containing the next data for "split".

    Args:
      split: A string, either 'train', 'valid', or 'test'.

    Returns:
      An EpisodeDataset.
    """
        shuffle_buffer_size = self.data_config.shuffle_buffer_size
        read_buffer_size_bytes = self.data_config.read_buffer_size_bytes
        benchmark_spec = (self.valid_benchmark_spec
                          if split == 'valid' else self.benchmark_spec)
        (_, image_shape, dataset_spec_list, has_dag_ontology,
         has_bilevel_ontology) = benchmark_spec
        episode_spec = self.split_episode_or_batch_specs[split]
        dataset_split, num_classes, num_train_examples, num_test_examples = \
            episode_spec
        # TODO(lamblinp): Support non-square shapes if necessary. For now, all
        # images are resized to square, even if it changes the aspect ratio.
        image_size = image_shape[0]
        if image_shape[1] != image_size:
            raise ValueError(
                'Expected a square image shape, not {}'.format(image_shape))

        # TODO(lamblinp): pass specs directly to the pipeline builder.
        # TODO(lamblinp): move the special case directly in make_..._pipeline
        if len(dataset_spec_list) == 1:
            use_dag_ontology = has_dag_ontology[0]
            if self.eval_finegrainedness or self.eval_imbalance_dataset:
                use_dag_ontology = False
            data_pipeline = pipeline.make_one_source_episode_pipeline(
                dataset_spec_list[0],
                use_dag_ontology=use_dag_ontology,
                use_bilevel_ontology=has_bilevel_ontology[0],
                split=dataset_split,
                num_ways=num_classes,
                num_support=num_train_examples,
                num_query=num_test_examples,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_size)
        else:
            data_pipeline = pipeline.make_multisource_episode_pipeline(
                dataset_spec_list,
                use_dag_ontology_list=has_dag_ontology,
                use_bilevel_ontology_list=has_bilevel_ontology,
                split=dataset_split,
                num_ways=num_classes,
                num_support=num_train_examples,
                num_query=num_test_examples,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_size)
            data_pipeline = apply_dataset_options(data_pipeline)

        iterator = data_pipeline.make_one_shot_iterator()
        (support_images, support_labels, support_class_ids, query_images,
         query_labels, query_class_ids) = iterator.get_next()
        return providers.EpisodeDataset(train_images=support_images,
                                        test_images=query_images,
                                        train_labels=support_labels,
                                        test_labels=query_labels,
                                        train_class_ids=support_class_ids,
                                        test_class_ids=query_class_ids)
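A hedged TF1-style usage sketch (the one-shot iterator above implies graph mode; `learner` and the explicit session are assumptions, not part of the original code):

    episode = learner.build_episode('train')
    with tf.compat.v1.Session() as sess:
        support_images, query_images = sess.run(
            [episode.train_images, episode.test_images])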
Example #9
  def test_episode_dataset_matches_meta_dataset(self,
                                                source,
                                                md_source,
                                                md_version,
                                                meta_split,
                                                remap_labels):

    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10

    builder = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path)
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.episode_dataset(
        builder, md_version, meta_split, source_id=0)
    # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
    # whereas the TFDS implementation intentionally keeps the v1 class order.
    if remap_labels:
      # The first argsort tells us where to look in class_names for position i
      # in the sorted class list. The second argsort reverses that: it tells us,
      # for position j in class_names, where to place that class in the sorted
      # class list.
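      # A tiny worked example (illustrative values only, not taken from the
      # test data):
      #   class_names = ['n03', 'n01', 'n02']
      #   np.argsort(class_names)             -> [1, 2, 0]
      #   np.argsort(np.argsort(class_names)) -> [2, 0, 1]
      # so original label 0 ('n03') is remapped to 2, its position in the
      # sorted class list.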
      label_map = np.argsort(np.argsort(builder.info.metadata['class_names']))
      label_remap_fn = np.vectorize(lambda x: label_map[x])
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))
    sampling.RNG = np.random.RandomState(seed)
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
        use_bilevel_ontology=dataset_spec.name == 'omniglot',
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      if remap_labels:
        tfds_episode = list(tfds_episode)
        tfds_episode[2] = label_remap_fn(tfds_episode[2])
        tfds_episode[5] = label_remap_fn(tfds_episode[5])
        tfds_episode = tuple(tfds_episode)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)