Example 1
    def _init_single_source_dataset(self, dataset_name, split):
        """Builds a single-source batch pipeline for `dataset_name` and returns
        the next-element tensors of its one-shot iterator."""
        dataset_specs = self._get_dataset_spec(dataset_name)
        self.specs_dict[split] = dataset_specs
        # Batch pipeline over a single dataset source (84x84 images).
        single_source_pipeline = pipeline.make_one_source_batch_pipeline(
            dataset_spec=dataset_specs, batch_size=self.batch_size,
            split=split, image_size=84)

        iterator = single_source_pipeline.make_one_shot_iterator()
        return iterator.get_next()
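
A minimal consumption sketch for the tensors returned above, assuming TF1 graph mode (implied by make_one_shot_iterator) and that the batch pipeline yields (images, labels) pairs as in Example 3. The loader object and the dataset/split names are hypothetical placeholders, not part of the original code.

import tensorflow as tf

# Hypothetical: `loader` is an instance of the class defining the method above.
images, labels = loader._init_single_source_dataset('omniglot', 'train')
with tf.Session() as sess:
    batch_images, batch_labels = sess.run([images, labels])
    # e.g. batch_images: [batch_size, 84, 84, 3], batch_labels: [batch_size]
    print(batch_images.shape, batch_labels.shape)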
Example 2
    def generate_meta_train_batch_pipeline(self):
        """Creates a batch generator for examples from the meta-train split,
        and an episode generator for examples from the meta-valid split.
        Episodes are used for meta-validation so that validation data is
        served the same way as its meta-test counterpart.

        The meta-valid data generator falls back to episode_config for its
        episode description if no dedicated configuration is provided.
        ----------------------------------------------------------------------
        Details on some arguments used inside the function:

        The following arguments are ignored since a fixed episode description
        is used:
            - max_ways_upper_bound, max_num_query, max_support_{},
                min/max_log_weight

        dag_ontology : only relevant for the ImageNet dataset. Ignored here.
        bilevel_ontology : whether to use Omniglot's bilevel (alphabet /
            character) ontology when sampling classes. Ignored here.
        min_examples : 0 means that no class is discarded for having too few
            examples.
        """
        self.meta_train_pipeline = pipeline.make_one_source_batch_pipeline(
            dataset_spec=self.dataset_spec,
            split=learning_spec.Split.TRAIN,
            batch_size=self.batch_size,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.image_size_batch,
            num_to_take=None
        )
        self.meta_valid_pipeline = pipeline.make_one_source_episode_pipeline(
            dataset_spec=self.dataset_spec,
            use_dag_ontology=False,
            use_bilevel_ontology=False,
            split=learning_spec.Split.VALID,
            episode_descr_config=self.fixed_ways_shots_valid,
            pool=None,
            shuffle_buffer_size=None,
            read_buffer_size_bytes=None,
            num_prefetch=0,
            image_size=self.valid_episode_config[0],
            num_to_take=None
        )
        logging.info('Meta-valid episode config : {}'.format(
            self.valid_episode_config))
        logging.info('Meta-train batch config : [{}, {}]'.format(
            self.batch_size, self.image_size_batch))
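
The two pipelines created above are tf.data.Dataset objects; as in Example 1, they are typically consumed through one-shot iterators. The following helper is a hypothetical sketch of that step (the method and attribute names are assumptions, and the exact structure yielded by the episode pipeline varies with the meta-dataset version):

    def generate_one_shot_iterators(self):
        """Hypothetical helper, not part of the original code: exposes the
        next-element tensors of the two pipelines built above."""
        train_iterator = self.meta_train_pipeline.make_one_shot_iterator()
        valid_iterator = self.meta_valid_pipeline.make_one_shot_iterator()
        # The batch pipeline yields (images, labels); the episode pipeline
        # yields support/query tensors whose layout depends on the library
        # version, so it is kept as an opaque structure here.
        self.meta_train_batch = train_iterator.get_next()
        self.meta_valid_episode = valid_iterator.get_next()
        return self.meta_train_batch, self.meta_valid_episode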
Example 3
    def build_batch(self, split):
        """Builds a Batch object containing the next data for "split".

        Args:
          split: A string, either 'train', 'valid', or 'test'.

        Returns:
          A providers.Batch of images and labels.
        """
        shuffle_buffer_size = self.data_config.shuffle_buffer_size
        read_buffer_size_bytes = self.data_config.read_buffer_size_bytes
        _, image_shape, dataset_spec_list, _, _ = self.benchmark_spec
        dataset_split, batch_size = self.split_episode_or_batch_specs[split]
        for dataset_spec in dataset_spec_list:
            if dataset_spec.name in DATASETS_WITH_EXAMPLE_SPLITS:
                raise ValueError(
                    'Batch pipeline is used only at meta-train time, and does not '
                    'handle datasets with example splits, which should only be used '
                    'at meta-test (evaluation) time.')
        # TODO(lamblinp): pass specs directly to the pipeline builder.
        # TODO(lamblinp): move the special case directly in make_..._pipeline
        if len(dataset_spec_list) == 1:
            data_pipeline = pipeline.make_one_source_batch_pipeline(
                dataset_spec_list[0],
                split=dataset_split,
                batch_size=batch_size,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_shape[0])
        else:
            data_pipeline = pipeline.make_multisource_batch_pipeline(
                dataset_spec_list,
                split=dataset_split,
                batch_size=batch_size,
                shuffle_buffer_size=shuffle_buffer_size,
                read_buffer_size_bytes=read_buffer_size_bytes,
                image_size=image_shape[0])

        data_pipeline = apply_dataset_options(data_pipeline)
        iterator = data_pipeline.make_one_shot_iterator()
        images, class_ids = iterator.get_next()
        return providers.Batch(images=images, labels=class_ids)
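
Example 3 returns a providers.Batch whose images and labels (as passed to its constructor above) are graph tensors. A minimal usage sketch under TF1 graph mode follows; the `learner` object owning build_batch is a hypothetical placeholder.

import tensorflow as tf

# Hypothetical: `learner` is an instance of the class defining build_batch above.
batch = learner.build_batch('train')
with tf.Session() as sess:
    images, labels = sess.run([batch.images, batch.labels])
    # images: [batch_size, image_shape[0], image_shape[0], channels] (assumed)
    print(images.shape, labels.shape)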