Ejemplo n.º 1
0
 def _get_dataset_spec(self, items):
     """Load the dataset specification(s) named by ``items``.

     Args:
         items: Either a single dataset name or a ``list`` of dataset names.
             Each name is joined onto ``self.data_path`` to form the path
             handed to ``dataset_spec_lib.load_dataset_spec``.

     Returns:
         A list of loaded specifications, in input order, when ``items`` is a
         ``list``; otherwise the single loaded specification.
     """
     def load(name):
         # One records directory per dataset, directly under the data root.
         return dataset_spec_lib.load_dataset_spec(
             os.path.join(self.data_path, name))

     if not isinstance(items, list):
         return load(items)
     return [load(name) for name in items]
Ejemplo n.º 2
0
def get_synsets_from_class_ids(class_ids):
    """Look up the Synsets of the analysis split for the given class ids.

    The ImageNet DatasetSpecification is loaded from the 'ilsvrc_2012'
    directory under FLAGS.records_root_dir. Every class id is translated to a
    WordNet id (e.g. n02075296) through the specification's `class_names`, and
    the matching Synsets are looked up in the subgraph of the split returned
    by `get_finegrainedness_split_enum()`.

    Args:
      class_ids: An iterable of class ids (described upstream as a np.array
        holding the two class ids chosen for an episode). Each id is used as
        a key into the ImageNet specification's `class_names`.

    Returns:
      A list of Synsets, one per entry of class_ids and in the same order.

    Raises:
      Whatever the specification loader raises. The original documentation
      states ValueError when the dataset specification is not found in the
      expected location (not verifiable from this function alone).
    """
    # ImageNet's specification lives under the records root.
    records_dir = os.path.join(FLAGS.records_root_dir, 'ilsvrc_2012')
    imagenet_data_spec = dataset_spec.load_dataset_spec(records_dir)

    # Subgraph (a set of Synsets) of the split used for the analysis.
    chosen_split = get_finegrainedness_split_enum()
    subgraph = imagenet_data_spec.split_subgraphs[chosen_split]

    # Class ids -> WordNet ids.
    wn_ids = [imagenet_data_spec.class_names[class_id]
              for class_id in class_ids]

    # WordNet ids -> Synsets of the subgraph, returned in request order.
    synsets_by_id = imagenet_spec.get_synsets_from_ids(wn_ids, subgraph)
    return [synsets_by_id[wn_id] for wn_id in wn_ids]
Ejemplo n.º 3
0
    def _init_multi_source_dataset(self, items, split, episode_description):
        """Build a multi-source episode pipeline and return its next-element op.

        Args:
            items: Sequence of dataset names; each one is resolved to a
                records directory under ``self.data_path``.
            split: Split passed through to the pipeline factory.
            episode_description: Episode description config passed through
                to the pipeline factory.

        Returns:
            The result of ``get_next()`` on a one-shot iterator over the
            multi-source episode pipeline (images are requested at size 84).
        """
        specs = [
            dataset_spec_lib.load_dataset_spec(
                os.path.join(self.data_path, name))
            for name in items
        ]

        source_count = len(items)
        bilevel_flags = [False] * source_count
        dag_flags = [False] * source_count
        # Ontology-aware sampling is switched on only for Omniglot (bilevel)
        # and ImageNet (DAG); only the first occurrence of a name is flagged.
        for flags, source in ((bilevel_flags, 'omniglot'),
                              (dag_flags, 'ilsvrc_2012')):
            if source in items:
                flags[items.index(source)] = True

        episodes = pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=specs,
            use_dag_ontology_list=dag_flags,
            use_bilevel_ontology_list=bilevel_flags,
            split=split,
            episode_descr_config=episode_description,
            image_size=84)

        return episodes.make_one_shot_iterator().get_next()
Ejemplo n.º 4
0
    def _init_single_source_dataset(self, dataset_name, split,
                                    episode_description):
        """Build a single-source episode pipeline and return its next-element op.

        Args:
            dataset_name: Name of the dataset directory under
                ``self.data_path``. Ontology-aware sampling is chosen by
                substring match on this name.
            split: Split passed through to the pipeline factory.
            episode_description: Episode description config passed through
                to the pipeline factory.

        Returns:
            The result of ``get_next()`` on a one-shot iterator over the
            episode pipeline (image size 84, shuffle buffer of 1000).
        """
        spec = dataset_spec_lib.load_dataset_spec(
            os.path.join(self.data_path, dataset_name))

        # Ontology-aware sampling: bilevel for Omniglot, DAG for ImageNet.
        bilevel = 'omniglot' in dataset_name
        dag = 'ilsvrc_2012' in dataset_name

        episodes = pipeline.make_one_source_episode_pipeline(
            dataset_spec=spec,
            use_dag_ontology=dag,
            use_bilevel_ontology=bilevel,
            split=split,
            episode_descr_config=episode_description,
            image_size=84,
            shuffle_buffer_size=1000)

        return episodes.make_one_shot_iterator().get_next()
Ejemplo n.º 5
0
  def __init__(self, datasets, split, fixed_ways=None, fixed_support=None,
               fixed_query=None, use_ontology=True):
    """Set up a multi-source Meta-Dataset episode pipeline.

    Args:
      datasets: list or tuple of dataset names (the first element must be a
        str), e.g. 'aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012',
        'omniglot', 'quickdraw', 'vgg_flower'. Each name is resolved to a
        directory under BASE_PATH.
      split: one of 'train', 'valid' or 'test'.
      fixed_ways: forwarded to the episode description config; when truthy
        together with `use_ontology` it is used as both `min_ways` and
        `max_ways_upper_bound` (with `num_ways=None`), otherwise as
        `num_ways`.
      fixed_support: forwarded as `num_support`.
      fixed_query: forwarded as `num_query`.
      use_ontology: when truthy, bilevel ontology sampling is enabled for
        'omniglot' and DAG ontology sampling for 'ilsvrc_2012'.

    Sets `self.datasets`, `self.episode_config`, `self.episode_pipeline` and
    `self.n_total_classes`.
    """
    assert split in ['train', 'valid', 'test']
    assert isinstance(datasets, (list, tuple))
    assert isinstance(datasets[0], str)

    print(f'Loading MetaDataset for {split}..')
    split = getattr(learning_spec.Split, split.upper())
    self.datasets = datasets

    # Per-source ontology switches, aligned with self.datasets.
    ontology_enabled = bool(use_ontology)
    use_bilevel_ontology_list = [
        ontology_enabled and name == 'omniglot' for name in self.datasets]
    use_dag_ontology_list = [
        ontology_enabled and name == 'ilsvrc_2012' for name in self.datasets]

    assert len(self.datasets) == len(use_bilevel_ontology_list)
    assert len(self.datasets) == len(use_dag_ontology_list)

    # One dataset specification per source.
    all_dataset_specs = [
        dataset_spec_lib.load_dataset_spec(os.path.join(BASE_PATH, name))
        for name in self.datasets
    ]

    # Episode description config (per the original note, arguments left as
    # None fall back to the gin configuration).
    episode_kwargs = dict(num_query=fixed_query, num_support=fixed_support)
    if fixed_ways and use_ontology:
      episode_kwargs.update(min_ways=fixed_ways,
                            max_ways_upper_bound=fixed_ways,
                            num_ways=None)
    else:
      episode_kwargs.update(num_ways=fixed_ways)
    self.episode_config = config.EpisodeDescriptionConfig(**episode_kwargs)

    # Episode pipeline
    episode_pipeline, n_total_classes = (
        pipeline.make_multisource_episode_pipeline2(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=self.episode_config,
            split=split, image_size=84))
    self.episode_pipeline = episode_pipeline
    self.n_total_classes = n_total_classes

    print('MetaDataset loaded: ', ', '.join(self.datasets))
Ejemplo n.º 6
0
  def test_meta_dataset(self):
    """Checks `api.meta_dataset` against the original multi-source pipeline.

    Ten 'valid' episodes are drawn from the TFDS-backed `api.meta_dataset`
    and from `pipeline.make_multisource_episode_pipeline` over the same three
    sources. `sampling.RNG` is reset to the same seed before each pipeline is
    built and both receive the same `source_sampling_seed`, so the test
    expects identical source ids and numerically close tensors.

    NOTE: statement order matters here; the gin parsing, the RNG resets and
    the `gin.bind_parameter` call each affect what follows them.
    """
    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

    # Episodes from the TFDS-backed implementation.
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources, md_version='v1', meta_split=meta_split,
        source_sampling_seed=seed + 1, data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    # Reset the sampling RNG to the same seed before building the reference
    # pipeline.
    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(os.path.join(FLAGS.meta_dataset_path,
                                                        md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined, but
    # the call to `gin.parse_config_file` above changes the default value to
    # True, which is why we have to explicitly bind a new default value here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    # Compare the two episode streams pairwise: same source id, close tensors.
    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
Ejemplo n.º 7
0
    def _init_dataset(self, dataset, split, episode_description):
        """Build a single-source episode pipeline without ontology sampling.

        Args:
            dataset: Name of the dataset directory under ``self.data_path``.
            split: Split passed through to the pipeline factory.
            episode_description: Episode description config passed through
                to the pipeline factory.

        Returns:
            The result of ``get_next()`` on a one-shot iterator over the
            episode pipeline (image size 84, both ontology flags off).
        """
        spec = dataset_spec_lib.load_dataset_spec(
            os.path.join(self.data_path, dataset))

        pipeline_kwargs = dict(
            dataset_spec=spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=split,
            episode_descr_config=episode_description,
            image_size=84,
        )
        episodes = pipeline.make_one_source_episode_pipeline(**pipeline_kwargs)

        return episodes.make_one_shot_iterator().get_next()
Ejemplo n.º 8
0
def main(unused_argv):
    """Sample episodes from one dataset and dump them as tfrecord files.

    Reads its configuration from FLAGS and gin, builds a single-source
    episode pipeline for FLAGS.dataset_name, takes FLAGS.num_episodes
    episodes, and writes a 'train' and a 'test' tfrecord file per episode
    under FLAGS.output_dir. A JSON file mapping each written file's basename
    to its label counts is written to the path given by `utils.get_info_path`.

    Args:
        unused_argv: Ignored positional command-line arguments.
    """
    output_dir = FLAGS.output_dir
    logging.info(output_dir)
    tf.io.gfile.makedirs(output_dir)
    gin.parse_config_files_and_bindings(
        FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)

    source_name = FLAGS.dataset_name
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, source_name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()

    # Ontology-aware sampling applies to ImageNet (DAG) and Omniglot
    # (bilevel) unless explicitly disabled through the matching flag.
    wants_dag = (source_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                 and not FLAGS.ignore_dag_ontology)
    wants_bilevel = (source_name == 'omniglot'
                     and not FLAGS.ignore_bilevel_ontology)

    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology=wants_dag,
        use_bilevel_ontology=wants_bilevel,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
        # reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)
    episodes = data_pipeline.take(FLAGS.num_episodes)

    label_counts_by_file = {}
    # The second element of each item is the source id; it is ignored since
    # a single dataset is loaded.
    for episode_number, (episode, _) in enumerate(episodes):
        logging.info('Dumping episode %d', episode_number)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        for subset, imgs, labels in (('train', train_imgs, train_labels),
                                     ('test', test_imgs, test_labels)):
            record_path = utils.get_file_path(output_dir, episode_number,
                                              subset)
            utils.dump_as_tfrecord(record_path, imgs, labels)
            label_counts_by_file[os.path.basename(record_path)] = (
                utils.get_label_counts(labels))

    with tf.io.gfile.GFile(utils.get_info_path(output_dir), 'w') as f:
        f.write(json.dumps(label_counts_by_file, indent=2))
0
def make_md(lst, method, split='train', image_size=126, **kwargs):
    """Build a Meta-Dataset input pipeline over several sources.

    Args:
        lst: Sequence of dataset names; each one is resolved to a directory
            under BASE_PATH and its dataset specification is loaded.
        method: 'episodic' for a multi-source episode pipeline, or 'batch'
            for a multi-source batch pipeline.
        split: 'train', 'val' (or its alias 'valid'), or 'test'.
        image_size: Image size forwarded to the pipeline factory.
        **kwargs: 'batch_size' is required when method is 'batch'.

    Returns:
        The pipeline object returned by the corresponding factory in
        `pipeline`.

    Raises:
        ValueError: If `split` or `method` is not one of the supported
            values. Raised before any dataset is loaded.
        KeyError: If method is 'batch' and 'batch_size' is not supplied.
    """
    # Validate up front: previously an unknown split surfaced as an
    # UnboundLocalError only after every dataset had been loaded, and an
    # unknown method silently returned None.
    split_attr_by_name = {
        'train': 'TRAIN',
        'val': 'VALID',
        'valid': 'VALID',
        'test': 'TEST',
    }
    if split not in split_attr_by_name:
        raise ValueError(
            "Unknown split {!r}; expected one of 'train', 'val', 'valid', "
            "'test'.".format(split))
    if method not in ('episodic', 'batch'):
        raise ValueError(
            "Unknown method {!r}; expected 'episodic' or 'batch'.".format(
                method))

    md_split = getattr(learning_spec.Split, split_attr_by_name[split])

    all_dataset_specs = [
        dataset_spec_lib.load_dataset_spec(os.path.join(BASE_PATH, name))
        for name in lst
    ]

    if method == 'episodic':
        # Enable ontology aware sampling for Omniglot and ImageNet.
        use_dag_ontology_list = [name == 'ilsvrc_2012' for name in lst]
        use_bilevel_ontology_list = [name == 'omniglot' for name in lst]
        variable_ways_shots = config.EpisodeDescriptionConfig(
            num_query=None, num_support=None, num_ways=None)

        return pipeline.make_multisource_episode_pipeline(
            dataset_spec_list=all_dataset_specs,
            use_dag_ontology_list=use_dag_ontology_list,
            use_bilevel_ontology_list=use_bilevel_ontology_list,
            episode_descr_config=variable_ways_shots,
            split=md_split, image_size=image_size)

    # method == 'batch'
    return pipeline.make_multisource_batch_pipeline(
        dataset_spec_list=all_dataset_specs,
        batch_size=kwargs['batch_size'],
        split=md_split,
        image_size=image_size,
        add_dataset_offset=False)
Ejemplo n.º 10
0
def main(argv):
    """Check a directory of tfrecords against its dataset specification.

    Four checks are run and every failure is logged at error level, so that
    all problems are reported in one pass:
      1. the specification's name matches the records directory name;
      2. the tfrecords files present match the files the specification
         expects (one per class);
      3. each file holds the expected number of examples;
      4. each file only holds labels equal to its class id.

    Args:
        argv: Command-line arguments; anything beyond the program name is
            rejected.

    Raises:
        app.UsageError: If extra command-line arguments are given.
        ValueError: If at least one check failed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Report failed checks as they occur and maintain a counter, instead of
    # raising exceptions right away, so all issues can be reported at once.
    num_failed_checks = 0

    # Load dataset_spec, this should fail if it is absent or incorrect.
    if FLAGS.dataset_spec_file is None:
        dataset_spec = dataset_spec_lib.load_dataset_spec(
            FLAGS.dataset_records_path)
    else:
        with tf.io.gfile.GFile(FLAGS.dataset_spec_file, 'r') as f:
            dataset_spec = json.load(
                f, object_hook=dataset_spec_lib.as_dataset_spec)

    dataset_spec.initialize()

    # 1. Check dataset name
    dir_name = os.path.basename(os.path.abspath(FLAGS.dataset_records_path))
    if dataset_spec.name != dir_name:
        num_failed_checks += 1
        logging.error(
            'The dataset name in "dataset_spec.json" (%s) does not match '
            'the name of the directory containing it (%s)', dataset_spec.name,
            dir_name)

    # 2. Check name and number of .tfrecords files
    num_classes = len(dataset_spec.class_names)
    try:
        expected_filenames = [
            dataset_spec.file_pattern.format(class_id)
            for class_id in range(num_classes)
        ]
    except (IndexError, KeyError):
        # str.format raises IndexError for extra positional fields and
        # KeyError for named fields; both mean the pattern does not take the
        # class number as its only argument.
        num_failed_checks += 1
        err_msg = (
            'The `file_pattern` (%s) did not accept the class number as its only '
            'formatting argument. Using the default (%s).')
        default_pattern = '{}.tfrecords'
        logging.error(err_msg, dataset_spec.file_pattern, default_pattern)

        expected_filenames = [
            default_pattern.format(class_id) for class_id in range(num_classes)
        ]

    all_filenames = tf.io.gfile.listdir(FLAGS.dataset_records_path)
    # Heuristic to exclude obviously-not-tfrecords files.
    expected_set = set(expected_filenames)
    present_set = {f for f in all_filenames if 'tfrecords' in f.lower()}
    if expected_set != present_set:
        num_failed_checks += 1
        logging.error(
            'The tfrecords files in %s do not match the dataset_spec.\n'
            'Unexpected files present:\n'
            '%s\n'
            'Expected files not present:\n'
            '%s', FLAGS.dataset_records_path,
            sorted(present_set - expected_set),
            sorted(expected_set - present_set))

    # Iterate through each dataset, count examples and check set of targets.
    # List of (class_id, expected_count, actual_count) triples.
    bad_counts = []
    # List of (filename, class_id, labels).
    bad_labels = []

    for class_id, filename in enumerate(expected_filenames):
        expected_count = dataset_spec.get_total_images_per_class(class_id)
        if filename not in present_set:
            # The tfrecords does not exist, we use a negative count to denote it.
            bad_counts.append((class_id, expected_count, -1))
            bad_labels.append((filename, class_id, set()))
            continue
        full_filepath = os.path.join(FLAGS.dataset_records_path, filename)

        try:
            count, labels = get_count_and_labels(full_filepath,
                                                 FLAGS.label_field_name)
        except tf.errors.InvalidArgumentError:
            logging.exception(
                'Unable to find label (%s) in the tf.Examples of file %s. '
                'Maybe try a different --label_field_name.',
                FLAGS.label_field_name, filename)
            # Fall back to counting examples only.
            count = count_records(full_filepath)
            labels = set()
        if count != expected_count:
            bad_counts.append((class_id, expected_count, count))
        if labels != {class_id}:
            # labels could include class_id among other, incorrect labels.
            bad_labels.append((filename, class_id, labels))

    # 3. Check number of examples
    if bad_counts:
        num_failed_checks += 1
        # The header names class_id because that is what the triples hold.
        logging.error(
            'The number of tfrecords in the following files do not match '
            'the expected number of examples in that class.\n'
            '(class_id, expected, actual)  # -1 denotes a missing file.\n'
            '%s', bad_counts)

    # 4. Check the targets stored in the tfrecords files.
    if bad_labels:
        num_failed_checks += 1
        logging.error(
            'The labels stored inside the tfrecords (in field %s) do not '
            'all match the expected value (class_id).\n'
            '(filename, class_id, values)\n'
            '%s', FLAGS.label_field_name, bad_labels)

    # Report results
    if num_failed_checks:
        raise ValueError('%d checks failed. See the error-level logs.' %
                         num_failed_checks)
Ejemplo n.º 11
0
    def __init__(self,
                 path_to_records,
                 batch_config=None,
                 episode_config=None,
                 valid_episode_config=None,
                 pool='train',
                 mode='episode'):
        """
        Args:
            path_to_records: Absolute path of the tfrecords from which data
                will be generated. Should have the form ~/meta_train or
                ~/meta_test
            batch_config: Array-like. In batch mode, controls
                the size of batches and decoded images generated. Ignored
                otherwise. [image_size, batch_size]
            episode_config: Array-like. Describes the episode configuration. If
                pool='train' and mode='episode', it sets the meta-train
                episodes configuration generator. If pool='test', it sets the
                meta-test episodes configuration generator.
                [image_size, num_ways, num_examples_per_class_in_support,
                    num_total_query_examples]
                Defaults to [28, 5, 1, 19] when left as None.
            valid_episode_config: Array-like. Sets the episode configuration
                for the meta-valid episodes generator. It is only relevant in
                the pool='train' and mode='episode' setting, since
                both meta-train and meta-valid split could have their own
                episodes configuration. Defaults to [28, 5, 1, 19] when left
                as None.
            pool: The split from which images are taken. Only 'train' and 'test'
                pool are allowed. Automatically create a meta-validation
                generator if pool='train'.
            mode: The configuration of the data coming from the meta-train
                generator. 'batch' and 'episode' are available.

        Raises:
            ValueError: If pool or mode is not an allowed value, if batch
                mode is requested with pool='test', or if the configuration
                required by the selected mode cannot be indexed as described
                above.
        """
        # The defaults are built per call: a list literal in the signature
        # would be one object shared by every instance that stores it.
        if episode_config is None:
            episode_config = [28, 5, 1, 19]
        if valid_episode_config is None:
            valid_episode_config = [28, 5, 1, 19]

        self.episode_config = episode_config
        self.valid_episode_config = valid_episode_config
        self.pool = pool
        self.mode = mode

        if self.pool not in ('train', 'test'):
            raise ValueError(
                ('In DataGenerator, only \'train\' or \'test\' '
                 'are valid arguments for pool. Received :{}').format(
                     self.pool))
        if self.mode not in ('episode', 'batch'):
            raise ValueError(
                ('In DataGenerator, only \'episode\' or \'batch\' '
                 'are valid arguments for mode. Received :{}').format(
                     self.mode))
        if self.pool == 'test' and self.mode == 'batch':
            raise ValueError((
                'In DataGenerator, batch mode is only available '
                'at meta-train time. Received pool : {} and mode : {}').format(
                    self.pool, self.mode))

        # Errors raised when a config is None, too short, or a mapping
        # without the expected integer keys.
        bad_config_errors = (TypeError, IndexError, KeyError)

        if self.mode == 'batch':
            try:
                self.image_size_batch = batch_config[0]
                self.batch_size = batch_config[1]
            except bad_config_errors as err:
                raise ValueError(
                    ('The batch_config argument in DataGenerator '
                     'is not defined properly. Make sure it has the form '
                     '[img_size, batch_size]. '
                     'Received batch_config : {}').format(batch_config)
                ) from err

        if self.mode == 'episode':
            # Both episode configurations must expose positions 0..3. Each
            # one is reported under its own name.
            for described_as, arg_name, cfg in (
                    ('episode config', 'episode_config',
                     self.episode_config),
                    ('valid_episode_config', 'valid_episode_config',
                     self.valid_episode_config)):
                try:
                    for position in range(4):
                        cfg[position]
                except bad_config_errors as err:
                    raise ValueError(
                        ('The {} argument in DataGenerator '
                         'is not defined properly. Make sure it has the form '
                         '[img_size, num_ways, num_shots, num_query]. '
                         'Received {} : {}').format(
                             described_as, arg_name, cfg)
                    ) from err

        self.dataset_spec = dataset_spec_lib.load_dataset_spec(path_to_records)

        # Loading root path.
        root_path = os.path.join(os.path.dirname(__file__), os.pardir,
                                 os.pardir)
        gin_path = os.path.join(root_path, 'metadl/gin/default/decoders.gin')
        gin.parse_config_file(gin_path)

        self.meta_train_pipeline = None
        self.meta_test_pipeline = None
        self.meta_valid_pipeline = None

        logging.info('Creating %s generator for meta-%s dataset.',
                     self.mode, self.pool)

        self.set_fixed_episode_config()
        if self.pool == 'train':
            if self.mode == 'episode':
                self.generate_meta_train_episodes_pipeline()
            else:
                self.generate_meta_train_batch_pipeline()
        else:
            self.generate_meta_test_episodes_pipeline()
Ejemplo n.º 12
0
  def test_full_ways_dataset(self, source, md_source, md_version, meta_split,
                             remap_labels):
    """Checks the labels served by `api.full_ways_dataset` for one split.

    Every batch must only contain labels of classes belonging to
    `meta_split` and none belonging to the other two splits. In addition, at
    least 75% of the full-size batches must have a label entropy above half
    of the maximum possible entropy.

    Args:
      source: source name passed to `api.full_ways_dataset` as `md_source`.
      md_source: directory name of the source under FLAGS.meta_dataset_path.
      md_version: Meta-Dataset version passed to `api.full_ways_dataset`.
      meta_split: 'train', 'valid' or 'test'.
      remap_labels: whether labels must be remapped to the sorted class-name
        order before being compared against the dataset specification.
    """
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))

    # Classes of the requested split are allowed; all others are forbidden.
    allowed_classes = set()
    forbidden_classes = set()
    for split_name in ('train', 'valid', 'test'):
      bucket = (allowed_classes if split_name == meta_split
                else forbidden_classes)
      bucket.update(
          dataset_spec.get_classes(
              getattr(learning_spec.Split, split_name.upper())))

    # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
    # whereas the TFDS implementation intentionally keeps the v1 class order.
    if remap_labels:
      builder = tfds.builder(
          'meta_dataset/ilsvrc_2012',
          data_dir=FLAGS.tfds_path
      )
      class_names = builder.info.metadata['class_names']
      # argsort(argsort(names))[j] is the rank of class j in the sorted
      # class list, i.e. the label it has after sorting by name.
      label_map = np.argsort(np.argsort(class_names))
      label_remap_fn = np.vectorize(lambda x: label_map[x])

    total_images = sum(
        dataset_spec.get_total_images_per_class(class_id)
        for class_id in allowed_classes)
    shuffle_buffer_size = min(10_000, total_images)
    batch_size = min(1024, total_images)

    dataset = api.full_ways_dataset(
        md_source=source,
        md_version=md_version,
        meta_split=meta_split,
        decoders={'image': tfds.decode.SkipDecoding()},
        data_dir=FLAGS.tfds_path,
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.prefetch(1)

    entropies = []
    for batch in dataset.as_numpy_iterator():
      labels = batch['label']
      if remap_labels:
        labels = label_remap_fn(labels)

      # Only full batches take part in the entropy check; the remainder
      # batch is not expected to be close to uniform.
      if len(labels) == batch_size:
        counts = collections.Counter(labels)
        p = np.array(
            [counts[class_id] for class_id in allowed_classes],
            dtype='float32'
        ) / batch_size
        entropies.append(-np.nan_to_num(p * np.log(p)).sum())

      seen_labels = set(labels)
      self.assertContainsSubset(seen_labels, allowed_classes)
      self.assertNoCommonElements(seen_labels, forbidden_classes)

    # The uniform distribution over the allowed classes has the maximum
    # entropy, log(number of classes); require more than half of it in more
    # than 75% of the full batches.
    entropy_floor = 0.5 * np.log(len(allowed_classes))
    high_entropy_fraction = (np.array(entropies) > entropy_floor).mean()
    self.assertGreater(high_entropy_fraction, 0.75)
Ejemplo n.º 13
0
  def test_episode_dataset_matches_meta_dataset(self,
                                                source,
                                                md_source,
                                                md_version,
                                                meta_split,
                                                remap_labels):
    """Checks `api.episode_dataset` against the original one-source pipeline.

    Ten episodes are drawn from the TFDS-backed `api.episode_dataset` and from
    `pipeline.make_one_source_episode_pipeline`, with `sampling.RNG` reset to
    the same seed before each pipeline is built. Source ids must be equal and
    tensors numerically close. When `remap_labels` is set, the tensors at
    positions 2 and 5 of each TFDS episode are remapped first.

    NOTE: statement order matters here; the gin parsing, the RNG resets and
    the `gin.bind_parameter` call each affect what follows them.

    Args:
      source: name of the `meta_dataset` TFDS builder config.
      md_source: directory name of the source under FLAGS.meta_dataset_path.
      md_version: Meta-Dataset version passed to `api.episode_dataset`.
      meta_split: split name; upper-cased to look up `learning_spec.Split`.
      remap_labels: whether TFDS labels must be remapped to the sorted
        class-name order before comparison.
    """

    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10

    builder = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path)
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.episode_dataset(
        builder, md_version, meta_split, source_id=0)
    # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
    # whereas the TFDS implementation intentionally keeps the v1 class order.
    if remap_labels:
      # The first argsort tells us where to look in class_names for position i
      # in the sorted class list. The second argsort reverses that: it tells us,
      # for position j in class_names, where to place that class in the sorted
      # class list.
      label_map = np.argsort(np.argsort(builder.info.metadata['class_names']))
      label_remap_fn = np.vectorize(lambda x: label_map[x])
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))
    # Reset the sampling RNG to the same seed before building the reference
    # pipeline.
    sampling.RNG = np.random.RandomState(seed)
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
        use_bilevel_ontology=dataset_spec.name == 'omniglot',
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    # Compare the two episode streams pairwise.
    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      if remap_labels:
        # Episodes are rebuilt as tuples after remapping positions 2 and 5
        # (presumably the support and query class-id tensors; confirm
        # against the episode layout).
        tfds_episode = list(tfds_episode)
        tfds_episode[2] = label_remap_fn(tfds_episode[2])
        tfds_episode[5] = label_remap_fn(tfds_episode[5])
        tfds_episode = tuple(tfds_episode)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
Ejemplo n.º 14
0
                break
            yield idx, episode


# 5
# Split the pipeline below is built for.
SPLIT = learning_spec.Split.TRAIN

### Reading datasets
# Full list of sources, kept for reference; only 'cu_birds' is active.
#ALL_DATASETS = ['aircraft', 'cu_birds', 'dtd', 'fungi', 'ilsvrc_2012',
#                'omniglot', 'quickdraw', 'vgg_flower']
ALL_DATASETS = ['cu_birds']

# Load one dataset specification per source from BASE_PATH/<dataset_name>.
all_dataset_specs = []
for dataset_name in ALL_DATASETS:
    dataset_records_path = os.path.join(BASE_PATH, dataset_name)
    dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
    all_dataset_specs.append(dataset_spec)

## (1) Episodic Mode
# One flag per source; ontology-aware sampling is off for every source here.
use_bilevel_ontology_list = [False] * len(ALL_DATASETS)
use_dag_ontology_list = [False] * len(ALL_DATASETS)

# Enable ontology aware sampling for Omniglot and ImageNet.
# NOTE: indices 5 and 4 are the positions of 'omniglot' and 'ilsvrc_2012' in
# the commented-out full list above; they are out of range for the
# single-source list, so the lines stay disabled.
#use_bilevel_ontology_list[5] = True
#use_dag_ontology_list[4] = True
# All three episode dimensions are passed as None.
variable_ways_shots = config.EpisodeDescriptionConfig(num_query=None,
                                                      num_support=None,
                                                      num_ways=None)

dataset_episodic = pipeline.make_multisource_episode_pipeline(
    dataset_spec_list=all_dataset_specs,