Example #1
  def test_meta_dataset(self):
    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10
    meta_split = 'valid'
    md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.meta_dataset(
        md_sources=md_sources, md_version='v1', meta_split=meta_split,
        source_sampling_seed=seed + 1, data_dir=FLAGS.tfds_path)
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    sampling.RNG = np.random.RandomState(seed)
    dataset_spec_list = [
        dataset_spec_lib.load_dataset_spec(os.path.join(FLAGS.meta_dataset_path,
                                                        md_source))
        for md_source in md_sources
    ]
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation. The kwarg defaults to False when the class is defined, but
    # the call to `gin.parse_config_file` above changes the default value to
    # True, which is why we have to explicitly bind a new default value here.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_multisource_episode_pipeline(
        dataset_spec_list=dataset_spec_list,
        use_dag_ontology_list=['ilsvrc_2012' in dataset_spec.name
                               for dataset_spec in dataset_spec_list],
        use_bilevel_ontology_list=[dataset_spec.name == 'omniglot'
                                   for dataset_spec in dataset_spec_list],
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height,
        source_sampling_seed=seed + 1
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
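
The comment in the example above notes that parsing `data_config_tfds.gin` flips the default of `ImageDecoder.skip_tfexample_decoding`, so the test has to rebind it explicitly. Below is a minimal standalone sketch of that gin mechanism, using a toy `decode` function rather than Meta-Dataset's real `ImageDecoder`:

import gin


@gin.configurable
def decode(skip_tfexample_decoding=False):
  return skip_tfexample_decoding


# A parsed config string (or file) rebinds the kwarg's effective default...
gin.parse_config('decode.skip_tfexample_decoding = True')
assert decode() is True
# ...and a later explicit binding overrides the parsed one again.
gin.bind_parameter('decode.skip_tfexample_decoding', False)
assert decode() is False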
Example #2
  def test_episodic_overfit(self,
                            learner_class,
                            learner_config,
                            threshold=1.,
                            attempts=1):
    """Test that error goes down when training on a single episode.

    This can help check that the trained model and the evaluated one share
    the trainable parameters correctly.

    Args:
      learner_class: A subclass of Learner.
      learner_config: A string, the Learner-specific gin configuration string.
      threshold: A float (default 1.), the validation accuracy to reach at least
        once.
      attempts: An int (default 1), how many of the last training steps to check
        for a validation accuracy that reaches the threshold.
    """
    gin_config = '\n'.join((self.BASE_GIN_CONFIG, learner_config))
    gin.parse_config(gin_config)

    episode_config = config.EpisodeDescriptionConfig(
        num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)

    trainer_instance = trainer.Trainer(
        train_learner_class=learner_class,
        eval_learner_class=learner_class,
        is_training=True,
        train_dataset_list=['dummy'],
        eval_dataset_list=['dummy'],
        records_root_dir=self.temp_dir,
        checkpoint_dir=os.path.join(self.temp_dir, 'checkpoints'),
        train_episode_config=episode_config,
        eval_episode_config=episode_config,
        data_config=config.DataConfig(),
        # BEGIN GOOGLE_INTERNAL
        real_episodes=False,
        real_episodes_results_dir='',
        # END GOOGLE_INTERNAL
    )
    # Train 1 update at a time for the last `attempts - 1` steps.
    trainer_instance.num_updates -= (attempts - 1)
    trainer_instance.train()
    valid_accs = [trainer_instance.valid_acc]
    for _ in range(attempts - 1):
      trainer_instance.num_updates += 1
      trainer_instance.train()
      valid_accs.append(trainer_instance.valid_acc)
    self.assertGreaterEqual(max(valid_accs), threshold)
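
The test above receives `learner_class` and `learner_config` as arguments, so it is presumably driven by a parameterized test decorator. A hypothetical sketch of how such cases could be wired up with `absl.testing.parameterized`; the case name, the gin string, and the `learner_lib` module alias are illustrative assumptions, not taken from the real test file:

from absl.testing import parameterized


class EpisodicOverfitTest(parameterized.TestCase):

  @parameterized.named_parameters(
      # (case name, learner_class, learner_config) -- illustrative values only.
      ('proto_net',
       learner_lib.PrototypicalNetworkLearner,
       'PrototypicalNetworkLearner.weight_decay = 1e-4'),
  )
  def test_episodic_overfit(self, learner_class, learner_config,
                            threshold=1., attempts=1):
    ...  # Body as in the example above.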
Example #3
def main(unused_argv):
    logging.info(FLAGS.output_dir)
    tf.io.gfile.makedirs(FLAGS.output_dir)
    gin.parse_config_files_and_bindings(FLAGS.gin_config,
                                        FLAGS.gin_bindings,
                                        finalize_config=True)
    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
    data_config = config.DataConfig()
    episode_descr_config = config.EpisodeDescriptionConfig()
    use_dag_ontology = (FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2')
                        and not FLAGS.ignore_dag_ontology)
    use_bilevel_ontology = (FLAGS.dataset_name == 'omniglot'
                            and not FLAGS.ignore_bilevel_ontology)
    data_pipeline = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology=use_dag_ontology,
        use_bilevel_ontology=use_bilevel_ontology,
        split=FLAGS.split,
        episode_descr_config=episode_descr_config,
        # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
        # reproducibility of dumping.
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        read_buffer_size_bytes=data_config.read_buffer_size_bytes,
        num_prefetch=data_config.num_prefetch)
    dataset = data_pipeline.take(FLAGS.num_episodes)

    images_per_class_dict = {}
    # Ignore the source id since we are loading a single dataset.
    for episode_number, (episode, _) in enumerate(dataset):
        logging.info('Dumping episode %d', episode_number)
        train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
        path_train = utils.get_file_path(FLAGS.output_dir, episode_number,
                                         'train')
        path_test = utils.get_file_path(FLAGS.output_dir, episode_number,
                                        'test')
        utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
        utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
        images_per_class_dict[os.path.basename(path_train)] = (
            utils.get_label_counts(train_labels))
        images_per_class_dict[os.path.basename(path_test)] = (
            utils.get_label_counts(test_labels))
    info_path = utils.get_info_path(FLAGS.output_dir)
    with tf.io.gfile.GFile(info_path, 'w') as f:
        f.write(json.dumps(images_per_class_dict, indent=2))
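
For reference, a hedged sketch of reading back the info file written above. It assumes `utils.get_label_counts` returns a plain `{label: count}` mapping, so the JSON maps each dumped TFRecord basename to its per-class image counts:

import json

import tensorflow as tf


def load_images_per_class(info_path):
  """Loads the {tfrecord basename: {label: count}} mapping dumped above."""
  with tf.io.gfile.GFile(info_path, 'r') as f:
    return json.loads(f.read())


# Example: total number of images in episode 0's train shard, reusing the same
# path helpers as in the dumping code above.
# counts = load_images_per_class(utils.get_info_path(FLAGS.output_dir))
# num_train_0 = sum(
#     counts[os.path.basename(
#         utils.get_file_path(FLAGS.output_dir, 0, 'train'))].values())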
Example #4
  def test_episodic_trainer(self):
    # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
    # building the objects explicitly.
    learn_config = trainer.LearnConfig(
        num_updates=100,
        batch_size=8,  # unused
        num_eval_episodes=10,
        checkpoint_every=10,
        validate_every=5,
        log_every=1,
        transductive_batch_norm=False,
    )

    learner_config = trainer.LearnerConfig(
        episodic=True,
        train_learner='PrototypicalNet',
        eval_learner='PrototypicalNet',
        pretrained_checkpoint='',
        checkpoint_for_eval='',
        embedding_network='four_layer_convnet',
        learning_rate=1e-4,
        decay_learning_rate=True,
        decay_every=5000,
        decay_rate=0.5,
        experiment_name='test',
        pretrained_source='',
    )

    # PrototypicalNetworkLearner is built automatically, so this test can only
    # pass values to its constructor through gin.
    gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)

    # Values for EpisodeDescriptionSampler
    gin.bind_parameter('EpisodeDescriptionSampler.min_ways', 5)
    gin.bind_parameter('EpisodeDescriptionSampler.max_ways_upper_bound', 50)
    gin.bind_parameter('EpisodeDescriptionSampler.max_num_query', 10)
    gin.bind_parameter('EpisodeDescriptionSampler.max_support_set_size', 500)
    gin.bind_parameter(
        'EpisodeDescriptionSampler.max_support_size_contrib_per_class', 100)
    gin.bind_parameter('EpisodeDescriptionSampler.min_log_weight',
                       -0.69314718055994529)  # np.log(0.5)
    gin.bind_parameter('EpisodeDescriptionSampler.max_log_weight',
                       0.69314718055994529)  # np.log(2)

    data_config = config.DataConfig(
        image_height=84,
        shuffle_buffer_size=20,
        read_buffer_size_bytes=(1024**2),
        num_prefetch=2,
    )

    episodic_trainer = trainer.EpisodicTrainer(
        train_learner=learner.PrototypicalNetworkLearner,
        eval_learner=learner.PrototypicalNetworkLearner,
        is_training=True,
        dataset_list=['mini_imagenet'],
        checkpoint_dir='',
        summary_dir='',
        eval_finegrainedness=False,
        eval_finegrainedness_split='',
        eval_imbalance_dataset='',
        num_train_classes=None,
        num_test_classes=None,
        num_train_examples=None,
        num_test_examples=None,
        learn_config=learn_config,
        learner_config=learner_config,
        data_config=data_config,
    )

    # Get the next train / valid / test episodes.
    train_episode = episodic_trainer.get_next('train')
    self.assertIsInstance(train_episode, providers.EpisodeDataset)

    # This isn't really a test. It just checks that things don't crash...
    print(
        episodic_trainer.sess.run([
            episodic_trainer.train_op, episodic_trainer.losses['train'],
            episodic_trainer.accs['train']
        ]))
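
The `EpisodeDescriptionSampler` bindings above can equivalently be written as a single gin config string, closer to how a `.gin` setup file would state them. A sketch with the same values (standard gin syntax, not a copy of any real config file):

import gin

_SAMPLER_GIN = """
EpisodeDescriptionSampler.min_ways = 5
EpisodeDescriptionSampler.max_ways_upper_bound = 50
EpisodeDescriptionSampler.max_num_query = 10
EpisodeDescriptionSampler.max_support_set_size = 500
EpisodeDescriptionSampler.max_support_size_contrib_per_class = 100
EpisodeDescriptionSampler.min_log_weight = -0.69314718055994529  # np.log(0.5)
EpisodeDescriptionSampler.max_log_weight = 0.69314718055994529   # np.log(2)
"""
# Same effect as the individual gin.bind_parameter calls in the test above.
gin.parse_config(_SAMPLER_GIN)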
Example #5
    def test_trainer(self):
        # PrototypicalNetworkLearner is built automatically, so this test can
        # only pass values to its constructor through gin.
        gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
        gin.bind_parameter(
            'PrototypicalNetworkLearner.backprop_through_moments', True)
        gin.bind_parameter(
            'PrototypicalNetworkLearner.transductive_batch_norm', False)
        gin.bind_parameter('PrototypicalNetworkLearner.embedding_fn',
                           'four_layer_convnet')

        # Values that can't be passed directly to EpisodeDescriptionConfig
        gin.bind_parameter('process_episode.support_decoder',
                           decoder.ImageDecoder())
        gin.bind_parameter('process_episode.query_decoder',
                           decoder.ImageDecoder())

        episode_config = config.EpisodeDescriptionConfig(
            num_ways=None,
            num_support=None,
            num_query=None,
            min_ways=5,
            max_ways_upper_bound=50,
            max_num_query=10,
            max_support_set_size=500,
            max_support_size_contrib_per_class=100,
            min_log_weight=np.log(0.5),
            max_log_weight=np.log(2),
            ignore_dag_ontology=False,
            ignore_bilevel_ontology=False)

        # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
        # building the objects explicitly.
        data_config = config.DataConfig(
            image_height=84,
            shuffle_buffer_size=20,
            read_buffer_size_bytes=(1024**2),
            num_prefetch=2,
        )

        trainer_instance = trainer.Trainer(
            train_learner_class=learner_lib.PrototypicalNetworkLearner,
            eval_learner_class=learner_lib.PrototypicalNetworkLearner,
            is_training=True,
            train_dataset_list=['mini_imagenet'],
            eval_dataset_list=['mini_imagenet'],
            restrict_classes={},
            restrict_num_per_class={},
            checkpoint_dir='',
            summary_dir='',
            records_root_dir=FLAGS.records_root_dir,
            eval_split=trainer.VALID_SPLIT,
            eval_finegrainedness=False,
            eval_finegrainedness_split='',
            eval_imbalance_dataset='',
            omit_from_saving_and_reloading='',
            train_episode_config=episode_config,
            eval_episode_config=episode_config,
            data_config=data_config,
            num_updates=100,
            batch_size=8,  # unused
            num_eval_episodes=10,
            checkpoint_every=10,
            validate_every=5,
            log_every=1,
            checkpoint_to_restore=None,
            learning_rate=1e-4,
            decay_learning_rate=True,
            decay_every=5000,
            decay_rate=0.5,
            experiment_name='test',
            pretrained_source='',
        )

        # Get the next train / valid / test episodes.
        train_episode = trainer_instance.next_data[trainer.TRAIN_SPLIT]
        self.assertIsInstance(train_episode, providers.EpisodeDataset)

        # This isn't really a test. It just checks that things don't crash...
        print(
            trainer_instance.sess.run([
                trainer_instance.train_op,
                trainer_instance.losses[trainer.TRAIN_SPLIT],
                trainer_instance.accuracies[trainer.TRAIN_SPLIT]
            ]))
Example #6
  def test_episode_dataset_matches_meta_dataset(self,
                                                source,
                                                md_source,
                                                md_version,
                                                meta_split,
                                                remap_labels):

    gin.parse_config_file(tfds.core.as_path(meta_dataset.__file__).parent /
                          'learn/gin/setups/data_config_tfds.gin')
    gin.parse_config(_DETERMINISTIC_CONFIG)
    data_config = config_lib.DataConfig()
    seed = 20210917
    num_episodes = 10

    builder = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path)
    sampling.RNG = np.random.RandomState(seed)
    tfds_episode_dataset = api.episode_dataset(
        builder, md_version, meta_split, source_id=0)
    # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
    # whereas the TFDS implementation intentionally keeps the v1 class order.
    if remap_labels:
      # The first argsort tells us where to look in class_names for position i
      # in the sorted class list. The second argsort reverses that: it tells us,
      # for position j in class_names, where to place that class in the sorted
      # class list.
      label_map = np.argsort(np.argsort(builder.info.metadata['class_names']))
      label_remap_fn = np.vectorize(lambda x: label_map[x])
    tfds_episodes = list(
        tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

    dataset_spec = dataset_spec_lib.load_dataset_spec(
        os.path.join(FLAGS.meta_dataset_path, md_source))
    sampling.RNG = np.random.RandomState(seed)
    # We should not skip TFExample decoding in the original Meta-Dataset
    # implementation.
    gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
    md_episode_dataset = pipeline.make_one_source_episode_pipeline(
        dataset_spec,
        use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
        use_bilevel_ontology=dataset_spec.name == 'omniglot',
        split=getattr(learning_spec.Split, meta_split.upper()),
        episode_descr_config=config_lib.EpisodeDescriptionConfig(),
        pool=None,
        shuffle_buffer_size=data_config.shuffle_buffer_size,
        image_size=data_config.image_height
    )
    md_episodes = list(
        md_episode_dataset.take(num_episodes).as_numpy_iterator())

    for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
        tfds_episodes, md_episodes):

      np.testing.assert_equal(tfds_source_id, md_source_id)

      if remap_labels:
        tfds_episode = list(tfds_episode)
        tfds_episode[2] = label_remap_fn(tfds_episode[2])
        tfds_episode[5] = label_remap_fn(tfds_episode[5])
        tfds_episode = tuple(tfds_episode)

      for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
        np.testing.assert_allclose(tfds_tensor, md_tensor)
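
The double argsort in the example above builds the permutation that takes labels assigned under the original (v1) class order to labels under alphabetically sorted class names. A small worked example with made-up class names (not real Meta-Dataset metadata):

import numpy as np

class_names = ['dog', 'elephant', 'cat']  # Original (v1) class order.
order = np.argsort(class_names)           # [2, 0, 1]: sorted order is cat, dog, elephant.
label_map = np.argsort(order)             # [1, 2, 0]: v1 label j -> rank of class j when sorted.
label_remap_fn = np.vectorize(lambda x: label_map[x])
print(label_remap_fn([0, 1, 2, 2]))       # -> [1 2 0 0]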