Example #1
    def _init_multi_source_dataset(self, items, split, episode_description):
        dataset_specs = []
        for dataset_name in items:
            dataset_records_path = os.path.join(self.data_path, dataset_name)
            dataset_spec = dataset_spec_lib.load_dataset_spec(
                dataset_records_path)
            dataset_specs.append(dataset_spec)

        use_bilevel_ontology_list = [False] * len(items)
        use_dag_ontology_list = [False] * len(items)
        # Enable ontology aware sampling for Omniglot and ImageNet.
        if 'omniglot' in items:
            use_bilevel_ontology_list[items.index('omniglot')] = True
        if 'ilsvrc_2012' in items:
            use_dag_ontology_list[items.index('ilsvrc_2012')] = True

        multi_source_pipeline = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            split=split,
            episode_descr_config=episode_description,
            image_size=84)

        iterator = multi_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
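
The method above returns TF 1.x tensors produced by a one-shot iterator. A minimal consumption sketch, assuming graph mode; `loader`, the source list, and `episode_config` are hypothetical stand-ins for however the surrounding class is instantiated:

episode, source_id = loader._init_multi_source_dataset(
    ['omniglot', 'aircraft'], learning_spec.Split.TRAIN, episode_config)
with tf.Session() as sess:
    # episode is a 6-tuple of tensors; indices 0/1 hold support images/labels.
    support_images, support_labels = sess.run([episode[0], episode[1]])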
Example #2
  def test_meta_dataset(self):
    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources, md_version='v1', meta_split=meta_split,
        source_sampling_seed=seed + 1, data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(os.path.join(FLAGS.meta_dataset_path,
                                                        md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined, but
    # the call to `gin.parse_config_file` above changes the default value to
    # True, which is why we have to explicitly bind a new default value here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
Example #3
def make_md(lst, method, split='train', image_size=126, **kwargs):
    if split == 'train':
        SPLIT = learning_spec.Split.TRAIN
    elif split == 'val':
        SPLIT = learning_spec.Split.VALID
    elif split == 'test':
        SPLIT = learning_spec.Split.TEST
    else:
        raise ValueError('Unknown split: {}'.format(split))

    ALL_DATASETS = lst
    all_dataset_specs = []
    for dataset_name in ALL_DATASETS:
        dataset_records_path = os.path.join(BASE_PATH, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
        all_dataset_specs.append(dataset_spec)

    if method == 'episodic':
        use_bilevel_ontology_list = [False]*len(ALL_DATASETS)
        use_dag_ontology_list = [False]*len(ALL_DATASETS)
        # Enable ontology aware sampling for Omniglot and ImageNet. 
        for i, s in enumerate(ALL_DATASETS):
            if s == 'ilsvrc_2012':
                use_dag_ontology_list[i] = True
            if s == 'omniglot':
                use_bilevel_ontology_list[i] = True
        variable_ways_shots = config.EpisodeDescriptionConfig(
            num_query=None, num_support=None, num_ways=None)

        dataset_episodic = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=variable_ways_shots,
            split=SPLIT, image_size=image_size)
        return dataset_episodic

    elif method == 'batch':
        BATCH_SIZE = kwargs['batch_size']
        ADD_DATASET_OFFSET = False
        dataset_batch = pipeline.make_multisource_batch_pipeline(
            dataset_spec_list=all_dataset_specs, batch_size=BATCH_SIZE, split=SPLIT,
            image_size=image_size, add_dataset_offset=ADD_DATASET_OFFSET)
        return dataset_batch
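
A minimal usage sketch for make_md above, assuming TF 2.x eager execution and that BASE_PATH points at converted Meta-Dataset records (the dataset names and episode count below are illustrative):

episodic_dataset = make_md(['aircraft', 'cu_birds'], method='episodic',
                           split='train')
for episode, source_id in episodic_dataset.take(1):
    # Each episode is a 6-tuple: (support images, support labels, support
    # class ids, query images, query labels, query class ids).
    support_images, support_labels = episode[0], episode[1]
    query_images, query_labels = episode[3], episode[4]
    print(support_images.shape, query_images.shape, int(source_id))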
Example #4
    def _init_multi_source_dataset(self, items, split, episode_description):
        dataset_specs = self._get_dataset_spec(items)
        self.specs_dict[split] = dataset_specs

        use_bilevel_ontology_list = [False] * len(items)
        use_dag_ontology_list = [False] * len(items)
        # Enable ontology aware sampling for Omniglot and ImageNet.
        if 'omniglot' in items:
            use_bilevel_ontology_list[items.index('omniglot')] = True
        if 'ilsvrc_2012' in items:
            use_dag_ontology_list[items.index('ilsvrc_2012')] = True

        multi_source_pipeline = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            split=split,
            episode_descr_config=episode_description,
            image_size=84)

        iterator = multi_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
Example #5
    def test_make_multisource_episode_pipeline_feature(self, decoder_type,
                                                       config_file_path):

        # Create some feature records and write them to a temp directory.
        feat_size = 64
        num_examples = 100
        num_classes = 10
        output_path = self.get_temp_dir()
        gin.parse_config_file(config_file_path)

        # 1-Write feature records to temp directory.
        self.rng = np.random.RandomState(0)
        class_features = []
        class_examples = []
        for class_id in range(num_classes):
            features = self.rng.randn(num_examples,
                                      feat_size).astype(np.float32)
            label = np.array(class_id).astype(np.int64)
            output_file = os.path.join(output_path,
                                       str(class_id) + '.tfrecords')
            examples = test_utils.write_feature_records(
                features, label, output_file)
            class_examples.append(examples)
            class_features.append(features)
        class_examples = np.stack(class_examples)
        class_features = np.stack(class_features)

        # 2-Read records back using multi-source pipeline.
        # DatasetSpecification to use in tests
        dataset_spec = DatasetSpecification(
            name=None,
            classes_per_split={
                learning_spec.Split.TRAIN: 5,
                learning_spec.Split.VALID: 2,
                learning_spec.Split.TEST: 3
            },
            images_per_class={i: num_examples
                              for i in range(num_classes)},
            class_names=None,
            path=output_path,
            file_pattern='{}.tfrecords')

        # Duplicate the dataset to simulate reading from multiple datasets.
        use_bilevel_ontology_list = [False] * 2
        use_dag_ontology_list = [False] * 2
        all_dataset_specs = [dataset_spec] * 2

        fixed_ways_shots = config.EpisodeDescriptionConfig(num_query=5,
                                                           num_support=5,
                                                           num_ways=5)

        dataset_episodic = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=fixed_ways_shots,
            split=learning_spec.Split.TRAIN,
            image_size=None)

        episode, _ = self.evaluate(
            dataset_episodic.make_one_shot_iterator().get_next())

        if decoder_type == 'feature':
            # 3-Check that support and query features are in class_features and have
            # the correct corresponding label.
            support_features, support_class_ids = episode[0], episode[2]
            query_features, query_class_ids = episode[3], episode[5]

            for feat, class_id in zip(list(support_features),
                                      list(support_class_ids)):
                abs_err = np.abs(
                    np.sum(class_features - feat[None][None], axis=-1))
                # Make sure the feature is present in the original data.
                self.assertEqual(abs_err.min(), 0.0)
                found_class_id = np.where(abs_err == 0.0)[0][0]
                self.assertEqual(found_class_id, class_id)

            for feat, class_id in zip(list(query_features),
                                      list(query_class_ids)):
                abs_err = np.abs(
                    np.sum(class_features - feat[None][None], axis=-1))
                # Make sure the feature is present in the original data.
                self.assertEqual(abs_err.min(), 0.0)
                found_class_id = np.where(abs_err == 0.0)[0][0]
                self.assertEqual(found_class_id, class_id)

        elif decoder_type == 'none':
            # 3-Check that support and query examples are in class_examples and have
            # the correct corresponding label.

            support_examples, support_class_ids = episode[0], episode[2]
            query_examples, query_class_ids = episode[3], episode[5]

            for example, class_id in zip(list(support_examples),
                                         list(support_class_ids)):
                found_class_id = np.where(class_examples == example)[0][0]
                self.assertEqual(found_class_id, class_id)

            for example, class_id in zip(list(query_examples),
                                         list(query_class_ids)):
                found_class_id = np.where(class_examples == example)[0][0]
                self.assertEqual(found_class_id, class_id)
Example #6
    def build_episode(self, split):
        """Builds an EpisodeDataset containing the next data for "split".

    Args:
      split: A string, either 'train', 'valid', or 'test'.

    Returns:
      An EpisodeDataset.
    """
        shuffle_buffer_size = self.data_config.shuffle_buffer_size
        read_buffer_size_bytes = self.data_config.read_buffer_size_bytes
        benchmark_spec = (self.valid_benchmark_spec
                          if split == 'valid' else self.benchmark_spec)
        (_, image_shape, dataset_spec_list, has_dag_ontology,
         has_bilevel_ontology) = benchmark_spec
        episode_spec = self.split_episode_or_batch_specs[split]
        dataset_split, num_classes, num_train_examples, num_test_examples = \
            episode_spec
        # TODO(lamblinp): Support non-square shapes if necessary. For now, all
        # images are resized to square, even if it changes the aspect ratio.
        image_size = image_shape[0]
        if image_shape[1] != image_size:
            raise ValueError(
                'Expected a square image shape, not {}'.format(image_shape))

        # TODO(lamblinp): pass specs directly to the pipeline builder.
        # TODO(lamblinp): move the special case directly in make_..._pipeline
        if len(dataset_spec_list) == 1:

            use_dag_ontology = has_dag_ontology[0]
            if self.eval_finegrainedness or self.eval_imbalance_dataset:
                use_dag_ontology = False
            data_pipeline = pipeline.make_one_source_episode_pipeline(
                dataset_spec_list[0],
                use_dag_ontology=use_dag_ontology,
                use_bilevel_ontology=has_bilevel_ontology[0],
                split=dataset_split,
                num_ways=num_classes,
                num_support=num_train_examples,
                num_query=num_test_examples,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_size)
        else:
            data_pipeline = pipeline.make_multisource_episode_pipeline(
                dataset_spec_list,
                use_dag_ontology_list=has_dag_ontology,
                use_bilevel_ontology_list=has_bilevel_ontology,
                split=dataset_split,
                num_ways=num_classes,
                num_support=num_train_examples,
                num_query=num_test_examples,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_size)
            data_pipeline = apply_dataset_options(data_pipeline)

        iterator = data_pipeline.make_one_shot_iterator()
        (support_images, support_labels, support_class_ids, query_images,
         query_labels, query_class_ids) = iterator.get_next()
        return providers.EpisodeDataset(train_images=support_images,
                                        test_images=query_images,
                                        train_labels=support_labels,
                                        test_labels=query_labels,
                                        train_class_ids=support_class_ids,
                                        test_class_ids=query_class_ids)
Example #7
    def test_make_multisource_episode_pipeline_feature(self):
        def iterate_dataset(dataset, n):
            """Iterate over dataset."""
            if not tf.executing_eagerly():
                iterator = dataset.make_one_shot_iterator()
                next_element = iterator.get_next()
                with tf.Session() as sess:
                    for idx in range(n):
                        yield idx, sess.run(next_element)
            else:
                for idx, episode in enumerate(dataset):
                    if idx == n:
                        break
                    yield idx, episode

        def write_feature_records(features, label, output_path):
            """Create a record file from features and labels.

      Args:
        features: An [n, m] numpy array of features.
        label: A numpy array containing the label.
        output_path: A string specifying the location of the record.
      """
            writer = tf.python_io.TFRecordWriter(output_path)
            with self.session(use_gpu=False) as sess:
                for feat in list(features):
                    feat_serial = sess.run(tf.io.serialize_tensor(feat))
                    # Write the example.
                    dataset_to_records.write_example(
                        feat_serial,
                        label,
                        writer,
                        input_key='image/embedding',
                        label_key='image/class/label')
            writer.close()

        # Create some feature records and write them to a temp directory.
        feat_size = 64
        num_examples = 100
        num_classes = 10
        output_path = self.get_temp_dir()
        gin.parse_config("""
        import meta_dataset.data.decoder
        EpisodeDescriptionConfig.min_ways = 5
        EpisodeDescriptionConfig.max_ways_upper_bound = 50
        EpisodeDescriptionConfig.max_num_query = 10
        EpisodeDescriptionConfig.max_support_set_size = 500
        EpisodeDescriptionConfig.max_support_size_contrib_per_class = 100
        EpisodeDescriptionConfig.min_log_weight = -0.69314718055994529  # np.log(0.5)
        EpisodeDescriptionConfig.max_log_weight = 0.69314718055994529  # np.log(2)
        EpisodeDescriptionConfig.ignore_dag_ontology = False
        EpisodeDescriptionConfig.ignore_bilevel_ontology = False
        process_episode.support_decoder = @FeatureDecoder()
        process_episode.query_decoder = @FeatureDecoder()
        """)

        # 1-Write feature records to temp directory.
        self.rng = np.random.RandomState(0)
        class_features = []
        for class_id in range(num_classes):
            features = self.rng.randn(num_examples,
                                      feat_size).astype(np.float32)
            label = np.array(class_id).astype(np.int64)
            output_file = os.path.join(output_path,
                                       str(class_id) + '.tfrecords')
            write_feature_records(features, label, output_file)
            class_features.append(features)
        class_features = np.stack(class_features)

        # 2-Read records back using multi-source pipeline.
        # DatasetSpecification to use in tests
        dataset_spec = DatasetSpecification(
            name=None,
            classes_per_split={
                learning_spec.Split.TRAIN: 5,
                learning_spec.Split.VALID: 2,
                learning_spec.Split.TEST: 3
            },
            images_per_class={i: num_examples
                              for i in range(num_classes)},
            class_names=None,
            path=output_path,
            file_pattern='{}.tfrecords')

        # Duplicate the dataset to simulate reading from multiple datasets.
        use_bilevel_ontology_list = [False] * 2
        use_dag_ontology_list = [False] * 2
        all_dataset_specs = [dataset_spec] * 2

        fixed_ways_shots = config.EpisodeDescriptionConfig(num_query=5,
                                                           num_support=5,
                                                           num_ways=5)

        dataset_episodic = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=fixed_ways_shots,
            split=learning_spec.Split.TRAIN,
            image_size=None)

        _, episode = next(iterate_dataset(dataset_episodic, 1))
        # 3-Check that support and query features are in class_features and have
        # the correct corresponding label.
        support_features, support_class_ids = episode[0], episode[2]
        query_features, query_class_ids = episode[3], episode[5]

        for feat, class_id in zip(list(support_features),
                                  list(support_class_ids)):
            abs_err = np.abs(np.sum(class_features - feat[None][None],
                                    axis=-1))
            # Make sure the feature is present in the original data.
            self.assertEqual(abs_err.min(), 0.0)
            found_class_id = np.where(abs_err == 0.0)[0][0]
            self.assertEqual(found_class_id, class_id)

        for feat, class_id in zip(list(query_features), list(query_class_ids)):
            abs_err = np.abs(np.sum(class_features - feat[None][None],
                                    axis=-1))
            # Make sure the feature is present in the original data.
            self.assertEqual(abs_err.min(), 0.0)
            found_class_id = np.where(abs_err == 0.0)[0][0]
            self.assertEqual(found_class_id, class_id)
Example #8
## (1) Episodic Mode
use_bilevel_ontology_list = [False] * len(ALL_DATASETS)
use_dag_ontology_list = [False] * len(ALL_DATASETS)

# Enable ontology aware sampling for Omniglot and ImageNet.
#use_bilevel_ontology_list[5] = True
#use_dag_ontology_list[4] = True
variable_ways_shots = config.EpisodeDescriptionConfig(num_query=None,
                                                      num_support=None,
                                                      num_ways=None)

dataset_episodic = pipeline.make_multisource_episode_pipeline(
    dataset_spec_list=all_dataset_specs,
    use_dag_ontology_list=use_dag_ontology_list,
    use_bilevel_ontology_list=use_bilevel_ontology_list,
    episode_descr_config=variable_ways_shots,
    split=SPLIT,
    image_size=128)

## (4) Using Meta-dataset with PyTorch
import torch
# 1. Convert TF episode tensors to PyTorch: labels to int64, images from
#    NHWC to NCHW layout.
to_torch_labels = lambda a: torch.from_numpy(a.numpy()).long()
to_torch_imgs = lambda a: torch.from_numpy(
    np.transpose(a.numpy(), (0, 3, 1, 2)))


# 2. Generator that yields up to n_batches episodes as PyTorch tensors.
def data_loader(n_batches):
    for i, (e, _) in enumerate(dataset_episodic):
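        # NOTE: The original snippet is truncated here. The loop body below is
        # an assumed completion following the common Meta-Dataset-with-PyTorch
        # pattern: support images/labels are at episode indices 0/1 and query
        # images/labels at indices 3/4.
        if i == n_batches:
            break
        yield (to_torch_imgs(e[0]), to_torch_labels(e[1]),
               to_torch_imgs(e[3]), to_torch_labels(e[4]))

# Illustrative usage (the batch count is arbitrary):
# for support_x, support_y, query_x, query_y in data_loader(n_batches=10):
#     ...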