def test_meta_dataset(self):
  gin.parse_config_file(
      tfds.core.as_path(meta_dataset.__file__).parent /
      'learn/gin/setups/data_config_tfds.gin')
  gin.parse_config(_DETERMINISTIC_CONFIG)
  data_config = config_lib.DataConfig()
  seed = 20210917
  num_episodes = 10
  meta_split = 'valid'
  md_sources = ('aircraft', 'cu_birds', 'vgg_flower')

  sampling.RNG = np.random.RandomState(seed)
  tfds_episode_dataset = api.meta_dataset(
      md_sources=md_sources,
      md_version='v1',
      meta_split=meta_split,
      source_sampling_seed=seed + 1,
      data_dir=FLAGS.tfds_path)
  tfds_episodes = list(
      tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

  sampling.RNG = np.random.RandomState(seed)
  dataset_spec_list = [
      dataset_spec_lib.load_dataset_spec(
          os.path.join(FLAGS.meta_dataset_path, md_source))
      for md_source in md_sources
  ]
  # We should not skip TFExample decoding in the original Meta-Dataset
  # implementation. The kwarg defaults to False when the class is defined, but
  # the call to `gin.parse_config_file` above changes the default value to
  # True, which is why we have to explicitly bind a new default value here.
  gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
  md_episode_dataset = pipeline.make_multisource_episode_pipeline(
      dataset_spec_list=dataset_spec_list,
      use_dag_ontology_list=[
          'ilsvrc_2012' in dataset_spec.name
          for dataset_spec in dataset_spec_list
      ],
      use_bilevel_ontology_list=[
          dataset_spec.name == 'omniglot'
          for dataset_spec in dataset_spec_list
      ],
      split=getattr(learning_spec.Split, meta_split.upper()),
      episode_descr_config=config_lib.EpisodeDescriptionConfig(),
      pool=None,
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      image_size=data_config.image_height,
      source_sampling_seed=seed + 1)
  md_episodes = list(
      md_episode_dataset.take(num_episodes).as_numpy_iterator())

  for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
      tfds_episodes, md_episodes):
    np.testing.assert_equal(tfds_source_id, md_source_id)
    for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
      np.testing.assert_allclose(tfds_tensor, md_tensor)
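
# A note on the `ImageDecoder.skip_tfexample_decoding` rebinding above:
# `gin.parse_config_file` can override a configurable's keyword default for
# every later instantiation, so the test has to bind the original default back
# explicitly. A minimal sketch of that gin behavior (the `Decoder` class below
# is illustrative, not the real `ImageDecoder`):
#
#   @gin.configurable
#   class Decoder:
#     def __init__(self, skip_tfexample_decoding=False):
#       self.skip = skip_tfexample_decoding
#
#   gin.parse_config('Decoder.skip_tfexample_decoding = True')
#   Decoder().skip  # -> True: the parsed config overrides the kwarg default.
#   gin.bind_parameter('Decoder.skip_tfexample_decoding', False)
#   Decoder().skip  # -> False: rebound to the original default.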
def test_episodic_overfit(self,
                          learner_class,
                          learner_config,
                          threshold=1.,
                          attempts=1):
  """Test that error goes down when training on a single episode.

  This can help check that the trained model and the evaluated one share
  the trainable parameters correctly.

  Args:
    learner_class: A subclass of Learner.
    learner_config: A string, the Learner-specific gin configuration string.
    threshold: A float (default 1.), the performance to reach at least once.
    attempts: An int (default 1), how many of the last steps should be
      checked when looking for a validation value reaching the threshold.
  """
  gin_config = '\n'.join((self.BASE_GIN_CONFIG, learner_config))
  gin.parse_config(gin_config)

  episode_config = config.EpisodeDescriptionConfig(
      num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)

  trainer_instance = trainer.Trainer(
      train_learner_class=learner_class,
      eval_learner_class=learner_class,
      is_training=True,
      train_dataset_list=['dummy'],
      eval_dataset_list=['dummy'],
      records_root_dir=self.temp_dir,
      checkpoint_dir=os.path.join(self.temp_dir, 'checkpoints'),
      train_episode_config=episode_config,
      eval_episode_config=episode_config,
      data_config=config.DataConfig(),
  )
  # Train 1 update at a time for the last `attempts - 1` steps.
  trainer_instance.num_updates -= (attempts - 1)
  trainer_instance.train()
  valid_accs = [trainer_instance.valid_acc]
  for _ in range(attempts - 1):
    trainer_instance.num_updates += 1
    trainer_instance.train()
    valid_accs.append(trainer_instance.valid_acc)
  self.assertGreaterEqual(max(valid_accs), threshold)
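
# Worked example of the `attempts` retry logic above: with
# `trainer_instance.num_updates == 100` and `attempts == 3`, the first
# `train()` call stops after update 98 and each loop iteration trains one
# more update (99, then 100), so `valid_accs` holds the validation accuracy
# at steps 98, 99 and 100. A hypothetical invocation (the learner and gin
# string are assumptions, not the original test matrix):
#
#   self.test_episodic_overfit(
#       learner_class=learner_lib.PrototypicalNetworkLearner,
#       learner_config='PrototypicalNetworkLearner.weight_decay = 1e-4',
#       threshold=1., attempts=3)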
def main(unused_argv):
  logging.info(FLAGS.output_dir)
  tf.io.gfile.makedirs(FLAGS.output_dir)
  gin.parse_config_files_and_bindings(
      FLAGS.gin_config, FLAGS.gin_bindings, finalize_config=True)
  dataset_spec = dataset_spec_lib.load_dataset_spec(
      os.path.join(FLAGS.records_root_dir, FLAGS.dataset_name))
  data_config = config.DataConfig()
  episode_descr_config = config.EpisodeDescriptionConfig()
  use_dag_ontology = (
      FLAGS.dataset_name in ('ilsvrc_2012', 'ilsvrc_2012_v2') and
      not FLAGS.ignore_dag_ontology)
  use_bilevel_ontology = (
      FLAGS.dataset_name == 'omniglot' and not FLAGS.ignore_bilevel_ontology)
  data_pipeline = pipeline.make_one_source_episode_pipeline(
      dataset_spec,
      use_dag_ontology=use_dag_ontology,
      use_bilevel_ontology=use_bilevel_ontology,
      split=FLAGS.split,
      episode_descr_config=episode_descr_config,
      # TODO(evcu) Maybe set the following to 0 to prevent shuffling and check
      # reproducibility of dumping.
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      read_buffer_size_bytes=data_config.read_buffer_size_bytes,
      num_prefetch=data_config.num_prefetch)
  dataset = data_pipeline.take(FLAGS.num_episodes)

  images_per_class_dict = {}
  # Ignoring dataset number since we are loading one dataset.
  for episode_number, (episode, _) in enumerate(dataset):
    logging.info('Dumping episode %d', episode_number)
    train_imgs, train_labels, _, test_imgs, test_labels, _ = episode
    path_train = utils.get_file_path(FLAGS.output_dir, episode_number, 'train')
    path_test = utils.get_file_path(FLAGS.output_dir, episode_number, 'test')
    utils.dump_as_tfrecord(path_train, train_imgs, train_labels)
    utils.dump_as_tfrecord(path_test, test_imgs, test_labels)
    images_per_class_dict[os.path.basename(path_train)] = (
        utils.get_label_counts(train_labels))
    images_per_class_dict[os.path.basename(path_test)] = (
        utils.get_label_counts(test_labels))
  info_path = utils.get_info_path(FLAGS.output_dir)
  with tf.io.gfile.GFile(info_path, 'w') as f:
    f.write(json.dumps(images_per_class_dict, indent=2))
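
# Example invocation (a sketch: the module path and flag values are
# illustrative assumptions; the flag names match the FLAGS referenced in
# `main` above):
#
#   python -m meta_dataset.data.dump_episodes \
#     --records_root_dir=/path/to/converted_records \
#     --dataset_name=omniglot \
#     --split=TEST \
#     --num_episodes=10 \
#     --output_dir=/tmp/dumped_episodes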
def test_episodic_trainer(self):
  # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
  # building the objects explicitly.
  learn_config = trainer.LearnConfig(
      num_updates=100,
      batch_size=8,  # unused
      num_eval_episodes=10,
      checkpoint_every=10,
      validate_every=5,
      log_every=1,
      transductive_batch_norm=False,
  )
  learner_config = trainer.LearnerConfig(
      episodic=True,
      train_learner='PrototypicalNet',
      eval_learner='PrototypicalNet',
      pretrained_checkpoint='',
      checkpoint_for_eval='',
      embedding_network='four_layer_convnet',
      learning_rate=1e-4,
      decay_learning_rate=True,
      decay_every=5000,
      decay_rate=0.5,
      experiment_name='test',
      pretrained_source='',
  )
  # PrototypicalNetworkLearner is built automatically, and this test does not
  # have the opportunity to pass values to its constructor except through gin.
  gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
  # Values for EpisodeDescriptionSampler.
  gin.bind_parameter('EpisodeDescriptionSampler.min_ways', 5)
  gin.bind_parameter('EpisodeDescriptionSampler.max_ways_upper_bound', 50)
  gin.bind_parameter('EpisodeDescriptionSampler.max_num_query', 10)
  gin.bind_parameter('EpisodeDescriptionSampler.max_support_set_size', 500)
  gin.bind_parameter(
      'EpisodeDescriptionSampler.max_support_size_contrib_per_class', 100)
  gin.bind_parameter('EpisodeDescriptionSampler.min_log_weight',
                     -0.69314718055994529)  # np.log(0.5)
  gin.bind_parameter('EpisodeDescriptionSampler.max_log_weight',
                     0.69314718055994529)  # np.log(2)
  data_config = config.DataConfig(
      image_height=84,
      shuffle_buffer_size=20,
      read_buffer_size_bytes=(1024**2),
      num_prefetch=2,
  )
  episodic_trainer = trainer.EpisodicTrainer(
      train_learner=learner.PrototypicalNetworkLearner,
      eval_learner=learner.PrototypicalNetworkLearner,
      is_training=True,
      dataset_list=['mini_imagenet'],
      checkpoint_dir='',
      summary_dir='',
      eval_finegrainedness=False,
      eval_finegrainedness_split='',
      eval_imbalance_dataset='',
      num_train_classes=None,
      num_test_classes=None,
      num_train_examples=None,
      num_test_examples=None,
      learn_config=learn_config,
      learner_config=learner_config,
      data_config=data_config,
  )
  # Get the next train / valid / test episodes.
  train_episode = episodic_trainer.get_next('train')
  self.assertIsInstance(train_episode, providers.EpisodeDataset)
  # This isn't really a test; it just checks that things don't crash.
  print(
      episodic_trainer.sess.run([
          episodic_trainer.train_op,
          episodic_trainer.losses['train'],
          episodic_trainer.accs['train'],
      ]))
def test_trainer(self):
  # PrototypicalNetworkLearner is built automatically, and this test does not
  # have the opportunity to pass values to its constructor except through gin.
  gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
  gin.bind_parameter('PrototypicalNetworkLearner.backprop_through_moments',
                     True)
  gin.bind_parameter('PrototypicalNetworkLearner.transductive_batch_norm',
                     False)
  gin.bind_parameter('PrototypicalNetworkLearner.embedding_fn',
                     'four_layer_convnet')
  # Values that can't be passed directly to EpisodeDescriptionConfig.
  gin.bind_parameter('process_episode.support_decoder',
                     decoder.ImageDecoder())
  gin.bind_parameter('process_episode.query_decoder', decoder.ImageDecoder())
  episode_config = config.EpisodeDescriptionConfig(
      num_ways=None,
      num_support=None,
      num_query=None,
      min_ways=5,
      max_ways_upper_bound=50,
      max_num_query=10,
      max_support_set_size=500,
      max_support_size_contrib_per_class=100,
      min_log_weight=np.log(0.5),
      max_log_weight=np.log(2),
      ignore_dag_ontology=False,
      ignore_bilevel_ontology=False)
  # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
  # building the objects explicitly.
  data_config = config.DataConfig(
      image_height=84,
      shuffle_buffer_size=20,
      read_buffer_size_bytes=(1024**2),
      num_prefetch=2,
  )
  trainer_instance = trainer.Trainer(
      train_learner_class=learner_lib.PrototypicalNetworkLearner,
      eval_learner_class=learner_lib.PrototypicalNetworkLearner,
      is_training=True,
      train_dataset_list=['mini_imagenet'],
      eval_dataset_list=['mini_imagenet'],
      restrict_classes={},
      restrict_num_per_class={},
      checkpoint_dir='',
      summary_dir='',
      records_root_dir=FLAGS.records_root_dir,
      eval_split=trainer.VALID_SPLIT,
      eval_finegrainedness=False,
      eval_finegrainedness_split='',
      eval_imbalance_dataset='',
      omit_from_saving_and_reloading='',
      train_episode_config=episode_config,
      eval_episode_config=episode_config,
      data_config=data_config,
      num_updates=100,
      batch_size=8,  # unused
      num_eval_episodes=10,
      checkpoint_every=10,
      validate_every=5,
      log_every=1,
      checkpoint_to_restore=None,
      learning_rate=1e-4,
      decay_learning_rate=True,
      decay_every=5000,
      decay_rate=0.5,
      experiment_name='test',
      pretrained_source='',
  )
  # Get the next train / valid / test episodes.
  train_episode = trainer_instance.next_data[trainer.TRAIN_SPLIT]
  self.assertIsInstance(train_episode, providers.EpisodeDataset)
  # This isn't really a test; it just checks that things don't crash.
  print(
      trainer_instance.sess.run([
          trainer_instance.train_op,
          trainer_instance.losses[trainer.TRAIN_SPLIT],
          trainer_instance.accuracies[trainer.TRAIN_SPLIT],
      ]))
def test_episode_dataset_matches_meta_dataset(self, source, md_source,
                                              md_version, meta_split,
                                              remap_labels):
  gin.parse_config_file(
      tfds.core.as_path(meta_dataset.__file__).parent /
      'learn/gin/setups/data_config_tfds.gin')
  gin.parse_config(_DETERMINISTIC_CONFIG)
  data_config = config_lib.DataConfig()
  seed = 20210917
  num_episodes = 10

  builder = tfds.builder('meta_dataset', config=source,
                         data_dir=FLAGS.tfds_path)
  sampling.RNG = np.random.RandomState(seed)
  tfds_episode_dataset = api.episode_dataset(
      builder, md_version, meta_split, source_id=0)

  # For MD-v2's ilsvrc_2012 source, train classes are sorted by class name,
  # whereas the TFDS implementation intentionally keeps the v1 class order.
  if remap_labels:
    # The first argsort tells us where to look in class_names for position i
    # in the sorted class list. The second argsort reverses that: it tells us,
    # for position j in class_names, where to place that class in the sorted
    # class list.
    label_map = np.argsort(np.argsort(builder.info.metadata['class_names']))
    label_remap_fn = np.vectorize(lambda x: label_map[x])

  tfds_episodes = list(
      tfds_episode_dataset.take(num_episodes).as_numpy_iterator())

  dataset_spec = dataset_spec_lib.load_dataset_spec(
      os.path.join(FLAGS.meta_dataset_path, md_source))
  sampling.RNG = np.random.RandomState(seed)
  # We should not skip TFExample decoding in the original Meta-Dataset
  # implementation.
  gin.bind_parameter('ImageDecoder.skip_tfexample_decoding', False)
  md_episode_dataset = pipeline.make_one_source_episode_pipeline(
      dataset_spec,
      use_dag_ontology='ilsvrc_2012' in dataset_spec.name,
      use_bilevel_ontology=dataset_spec.name == 'omniglot',
      split=getattr(learning_spec.Split, meta_split.upper()),
      episode_descr_config=config_lib.EpisodeDescriptionConfig(),
      pool=None,
      shuffle_buffer_size=data_config.shuffle_buffer_size,
      image_size=data_config.image_height)
  md_episodes = list(
      md_episode_dataset.take(num_episodes).as_numpy_iterator())

  for (tfds_episode, tfds_source_id), (md_episode, md_source_id) in zip(
      tfds_episodes, md_episodes):
    np.testing.assert_equal(tfds_source_id, md_source_id)
    if remap_labels:
      tfds_episode = list(tfds_episode)
      tfds_episode[2] = label_remap_fn(tfds_episode[2])
      tfds_episode[5] = label_remap_fn(tfds_episode[5])
      tfds_episode = tuple(tfds_episode)
    for tfds_tensor, md_tensor in zip(tfds_episode, md_episode):
      np.testing.assert_allclose(tfds_tensor, md_tensor)
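
# A minimal worked example of the double-argsort trick above (the
# `class_names` value is hypothetical):
#
#   class_names = ['n03', 'n01', 'n02']
#   np.argsort(class_names)              # -> [1, 2, 0]: index in class_names
#                                        #    of the i-th name in sorted order.
#   np.argsort(np.argsort(class_names))  # -> [2, 0, 1]: sorted-order position
#                                        #    of the j-th entry of class_names.
#
# So `label_map[x]` maps a label in the original (v1) class order to its
# position in the name-sorted class list.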