def main(unused_argv):
  tf.logging.info('FLAGS.gin_config: %s', FLAGS.gin_config)
  tf.logging.info('FLAGS.gin_bindings: %s', FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  learner_config = trainer.LearnerConfig()

  # Check for inconsistent or contradictory flag settings.
  if (learner_config.checkpoint_for_eval and
      learner_config.pretrained_checkpoint):
    raise ValueError(
        'Cannot define both checkpoint_for_eval and pretrained_checkpoint. '
        'The former restores all variables (including the global step), '
        'whereas the latter applies only at the start of training, to '
        'initialize the model from pre-trained weights; it is also only '
        'applicable to episodic models and restores only the embedding '
        'weights.')

  (train_datasets, eval_datasets,
   restrict_num_per_class) = trainer.get_datasets_and_restrictions()

  train_learner = None
  if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                           FLAGS.eval_finegrainedness_split == 'train'):
    # If eval_finegrainedness is True, even in pure evaluation mode we still
    # require a train learner, since we may perform this analysis on the
    # training sub-graph of ImageNet too.
    train_learner = NAME_TO_LEARNER[learner_config.train_learner]
  eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

  # Get a trainer or evaluator.
  trainer_kwargs = {
      'train_learner': train_learner,
      'eval_learner': eval_learner,
      'is_training': FLAGS.is_training,
      'train_dataset_list': train_datasets,
      'eval_dataset_list': eval_datasets,
      'restrict_num_per_class': restrict_num_per_class,
      'checkpoint_dir': FLAGS.train_checkpoint_dir,
      'summary_dir': FLAGS.summary_dir,
      'records_root_dir': FLAGS.records_root_dir,
      'eval_finegrainedness': FLAGS.eval_finegrainedness,
      'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
      'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
      'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
  }
  if learner_config.episodic:
    trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
    if learner_config.train_learner not in EPISODIC_LEARNERS:
      raise ValueError(
          'When "episodic" is True, "train_learner" should be an episodic '
          'one, among {}.'.format(EPISODIC_LEARNERS))
  else:
    trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
    if learner_config.train_learner not in BATCH_LEARNERS:
      raise ValueError(
          'When "episodic" is False, "train_learner" should be a batch one, '
          'among {}.'.format(BATCH_LEARNERS))

  mode = 'training' if FLAGS.is_training else 'evaluation'
  datasets = train_datasets if FLAGS.is_training else eval_datasets
  tf.logging.info('Starting %s for dataset(s) %s...', mode, datasets)

  # Record the Gin operative config string after setup, both in the logs and
  # in the checkpoint directory.
  gin_operative_config = gin.operative_config_str()
  tf.logging.info('gin configuration:\n%s', gin_operative_config)
  if FLAGS.train_checkpoint_dir:
    gin_log_file = os.path.join(FLAGS.train_checkpoint_dir,
                                'operative_config.gin')
    # If it exists already, rename it instead of overwriting it.
    # This just saves the previous one, not all the ones before.
    if tf.gfile.Exists(gin_log_file):
      tf.gfile.Rename(gin_log_file, gin_log_file + '.old', overwrite=True)
    with tf.io.gfile.GFile(gin_log_file, 'w') as f:
      f.write(gin_operative_config)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError(
          'Requested datasets {} for evaluation, but evaluation should be '
          'performed on individual datasets only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flush the event file to disk and close the file.
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
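# The `main` above consumes absl-style flags (`FLAGS.gin_config`,
# `FLAGS.is_training`, ...). A minimal sketch of how such a script is
# typically wired up; the flag definitions below are assumptions inferred
# from the `FLAGS.*` references in `main`, not copied from the actual
# launcher.
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_multi_string('gin_config', None,
                          'List of paths to Gin config files.')
flags.DEFINE_multi_string('gin_bindings', None,
                          'List of Gin parameter bindings.')
flags.DEFINE_bool('is_training', True, 'Whether we are in the training phase.')

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)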
def test_episodic_trainer(self):
  # Inspired from `learn/gin/default/debug_proto_mini_imagenet.gin`, but
  # building the objects explicitly.
  learn_config = trainer.LearnConfig(
      num_updates=100,
      batch_size=8,  # unused
      num_eval_episodes=10,
      checkpoint_every=10,
      validate_every=5,
      log_every=1,
      transductive_batch_norm=False,
  )
  learner_config = trainer.LearnerConfig(
      episodic=True,
      train_learner='PrototypicalNet',
      eval_learner='PrototypicalNet',
      pretrained_checkpoint='',
      checkpoint_for_eval='',
      embedding_network='four_layer_convnet',
      learning_rate=1e-4,
      decay_learning_rate=True,
      decay_every=5000,
      decay_rate=0.5,
      experiment_name='test',
      pretrained_source='',
  )
  # PrototypicalNetworkLearner is built automatically and this test does not
  # have the opportunity to pass values to its constructor except through gin.
  gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
  # Values for EpisodeDescriptionSampler.
  gin.bind_parameter('EpisodeDescriptionSampler.min_ways', 5)
  gin.bind_parameter('EpisodeDescriptionSampler.max_ways_upper_bound', 50)
  gin.bind_parameter('EpisodeDescriptionSampler.max_num_query', 10)
  gin.bind_parameter('EpisodeDescriptionSampler.max_support_set_size', 500)
  gin.bind_parameter(
      'EpisodeDescriptionSampler.max_support_size_contrib_per_class', 100)
  gin.bind_parameter('EpisodeDescriptionSampler.min_log_weight',
                     -0.69314718055994529)  # np.log(0.5)
  gin.bind_parameter('EpisodeDescriptionSampler.max_log_weight',
                     0.69314718055994529)  # np.log(2)
  data_config = config.DataConfig(
      image_height=84,
      shuffle_buffer_size=20,
      read_buffer_size_bytes=(1024**2),
      num_prefetch=2,
  )
  episodic_trainer = trainer.EpisodicTrainer(
      train_learner=learner.PrototypicalNetworkLearner,
      eval_learner=learner.PrototypicalNetworkLearner,
      is_training=True,
      dataset_list=['mini_imagenet'],
      checkpoint_dir='',
      summary_dir='',
      eval_finegrainedness=False,
      eval_finegrainedness_split='',
      eval_imbalance_dataset='',
      num_train_classes=None,
      num_test_classes=None,
      num_train_examples=None,
      num_test_examples=None,
      learn_config=learn_config,
      learner_config=learner_config,
      data_config=data_config,
  )

  # Get the next train / valid / test episodes.
  train_episode = episodic_trainer.get_next('train')
  self.assertIsInstance(train_episode, providers.EpisodeDataset)

  # This isn't really a test. It just checks that things don't crash...
  print(
      episodic_trainer.sess.run([
          episodic_trainer.train_op,
          episodic_trainer.losses['train'],
          episodic_trainer.accs['train'],
      ]))
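# `gin.bind_parameter` above mutates process-global Gin state, so bindings
# made in one test leak into subsequent tests in the same process. A minimal
# cleanup sketch (the class name is hypothetical; the original test class is
# not shown here):
class _GinIsolatedTestCase(tf.test.TestCase):

  def tearDown(self):
    # Drop all Gin bindings registered during the test.
    gin.clear_config()
    super(_GinIsolatedTestCase, self).tearDown()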
def main(unused_argv):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  learner_config = trainer.LearnerConfig()

  # Check for inconsistent or contradictory flag settings.
  if (learner_config.checkpoint_for_eval and
      learner_config.pretrained_checkpoint):
    raise ValueError(
        'Cannot define both checkpoint_for_eval and pretrained_checkpoint. '
        'The former restores all variables (including the global step), '
        'whereas the latter applies only at the start of training, to '
        'initialize the model from pre-trained weights; it is also only '
        'applicable to episodic models and restores only the embedding '
        'weights.')

  datasets = get_datasets()

  train_learner = None
  if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                           FLAGS.eval_finegrainedness_split == 'train'):
    # If eval_finegrainedness is True, even in pure evaluation mode we still
    # require a train learner, since we may perform this analysis on the
    # training sub-graph of ImageNet too.
    train_learner = NAME_TO_LEARNER[learner_config.train_learner]
  eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

  # Get a trainer or evaluator.
  if learner_config.episodic:
    trainer_instance = trainer.EpisodicTrainer(
        train_learner, eval_learner, FLAGS.is_training, datasets,
        FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
        FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
        FLAGS.eval_imbalance_dataset)
    if learner_config.train_learner not in EPISODIC_LEARNERS:
      raise ValueError(
          'When "episodic" is True, "train_learner" should be an episodic '
          'one, among {}.'.format(EPISODIC_LEARNERS))
  else:
    trainer_instance = trainer.BatchTrainer(
        train_learner, eval_learner, FLAGS.is_training, datasets,
        FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
        FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
        FLAGS.eval_imbalance_dataset)
    if learner_config.train_learner not in BATCH_LEARNERS:
      raise ValueError(
          'When "episodic" is False, "train_learner" should be a batch one, '
          'among {}.'.format(BATCH_LEARNERS))

  mode = 'training' if FLAGS.is_training else 'evaluation'
  tf.logging.info('Starting %s for dataset(s) %s...', mode, datasets)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError(
          'Requested datasets {} for evaluation, but evaluation should be '
          'performed on individual datasets only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flush the event file to disk and close the file. Guard against a missing
  # writer (e.g. when no summary_dir was given).
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
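# `NAME_TO_LEARNER`, `EPISODIC_LEARNERS` and `BATCH_LEARNERS` are referenced
# above but defined elsewhere. A plausible minimal shape, inferred from the
# learner classes used in the tests in this file; the exact entries are
# assumptions:
EPISODIC_LEARNERS = ['PrototypicalNet']
BATCH_LEARNERS = ['Baseline']
NAME_TO_LEARNER = {
    'PrototypicalNet': learner.PrototypicalNetworkLearner,
    'Baseline': learner.BaselineLearner,
}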
def main(unused_argv):
  # Parse Gin configurations passed to this script.
  parse_cmdline_gin_configurations()

  # Try to reload a previously recorded Gin configuration.
  # TODO(eringrant): Allow querying of a value to be bound without actually
  # binding it, to avoid the redundant call to
  # `parse_cmdline_gin_configurations` below.
  checkpoint_for_eval = gin.query_parameter(
      'LearnerConfig.checkpoint_for_eval')
  if checkpoint_for_eval and FLAGS.reload_eval_checkpoint_gin_config:
    eval_checkpoint_dir = os.path.dirname(checkpoint_for_eval)
    load_operative_gin_configurations(eval_checkpoint_dir)

    # Reload the command-line Gin configuration to allow overriding of the
    # Gin configuration loaded from the checkpoint directory.
    parse_cmdline_gin_configurations()

  # Wrap object instantiations to print out the full Gin configuration on
  # failure.
  try:
    learner_config = trainer.LearnerConfig()

    (train_datasets, eval_datasets, restrict_classes,
     restrict_num_per_class) = trainer.get_datasets_and_restrictions()

    train_learner = None
    if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                             FLAGS.eval_finegrainedness_split == 'train'):
      # If eval_finegrainedness is True, even in pure evaluation mode we still
      # require a train learner, since we may perform this analysis on the
      # training sub-graph of ImageNet too.
      train_learner = trainer.NAME_TO_LEARNER[learner_config.train_learner]
    eval_learner = trainer.NAME_TO_LEARNER[learner_config.eval_learner]

    # Get a trainer or evaluator.
    trainer_kwargs = {
        'train_learner': train_learner,
        'eval_learner': eval_learner,
        'is_training': FLAGS.is_training,
        'train_dataset_list': train_datasets,
        'eval_dataset_list': eval_datasets,
        'restrict_classes': restrict_classes,
        'restrict_num_per_class': restrict_num_per_class,
        'checkpoint_dir': FLAGS.train_checkpoint_dir,
        'summary_dir': FLAGS.summary_dir,
        'records_root_dir': FLAGS.records_root_dir,
        'eval_finegrainedness': FLAGS.eval_finegrainedness,
        'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
        'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
        'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
    }
    if learner_config.episodic:
      trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
      if learner_config.train_learner not in trainer.EPISODIC_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is True, "train_learner" should be an episodic '
            'one, among {}.'.format(trainer.EPISODIC_LEARNER_NAMES))
    else:
      trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
      if learner_config.train_learner not in trainer.BATCH_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is False, "train_learner" should be a batch '
            'one, among {}.'.format(trainer.BATCH_LEARNER_NAMES))
  except ValueError as e:
    logging.info('Full Gin configurations:\n%s', gin.config_str())
    raise e

  # All configurable objects/functions should have been instantiated/called.
  logging.info('Operative Gin configurations:\n%s', gin.operative_config_str())
  if FLAGS.is_training and FLAGS.train_checkpoint_dir:
    record_operative_gin_configurations(FLAGS.train_checkpoint_dir)

  datasets = train_datasets if FLAGS.is_training else eval_datasets
  logging.info('Starting %s for dataset(s) %s...',
               'training' if FLAGS.is_training else 'evaluation', datasets)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError(
          'Requested datasets {} for evaluation, but evaluation should be '
          'performed on individual datasets only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flush the event file to disk and close the file.
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
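# `record_operative_gin_configurations` is defined elsewhere. A sketch of
# what it plausibly does, modeled on the inline checkpoint-directory logic in
# the first `main` above (the exact helper body is an assumption):
def record_operative_gin_configurations(output_dir):
  """Writes Gin's operative config, renaming any previous one to `.old`."""
  gin_log_file = os.path.join(output_dir, 'operative_config.gin')
  # Rename rather than overwrite: this preserves only the immediately
  # preceding config, not all earlier ones.
  if tf.io.gfile.exists(gin_log_file):
    tf.io.gfile.rename(gin_log_file, gin_log_file + '.old', overwrite=True)
  with tf.io.gfile.GFile(gin_log_file, 'w') as f:
    f.write(gin.operative_config_str())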
def test_episodic_trainer(self):
  # Inspired from `learn/gin/default/debug_proto_mini_imagenet.gin`, but
  # building the objects explicitly.
  learn_config = trainer.LearnConfig(
      num_updates=100,
      batch_size=8,  # unused
      num_eval_episodes=10,
      checkpoint_every=10,
      validate_every=5,
      log_every=1,
      transductive_batch_norm=False,
  )
  learner_config = trainer.LearnerConfig(
      episodic=True,
      train_learner='PrototypicalNet',
      eval_learner='PrototypicalNet',
      pretrained_checkpoint='',
      checkpoint_for_eval='',
      embedding_network='four_layer_convnet',
      learning_rate=1e-4,
      decay_learning_rate=True,
      decay_every=5000,
      decay_rate=0.5,
      experiment_name='test',
      pretrained_source='',
  )
  # PrototypicalNetworkLearner is built automatically and this test does not
  # have the opportunity to pass values to its constructor except through gin.
  gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
  # Values that can't be passed directly to EpisodeDescriptionConfig.
  gin.bind_parameter('process_episode.support_decoder',
                     decoder.ImageDecoder())
  gin.bind_parameter('process_episode.query_decoder', decoder.ImageDecoder())
  episode_config = config.EpisodeDescriptionConfig(
      num_ways=None,
      num_support=None,
      num_query=None,
      min_ways=5,
      max_ways_upper_bound=50,
      max_num_query=10,
      max_support_set_size=500,
      max_support_size_contrib_per_class=100,
      min_log_weight=np.log(0.5),
      max_log_weight=np.log(2),
      ignore_dag_ontology=False,
      ignore_bilevel_ontology=False)
  data_config = config.DataConfig(
      image_height=84,
      shuffle_buffer_size=20,
      read_buffer_size_bytes=(1024**2),
      num_prefetch=2,
  )
  episodic_trainer = trainer.EpisodicTrainer(
      train_learner=learner.PrototypicalNetworkLearner,
      eval_learner=learner.PrototypicalNetworkLearner,
      is_training=True,
      train_dataset_list=['mini_imagenet'],
      eval_dataset_list=['mini_imagenet'],
      restrict_classes={},
      restrict_num_per_class={},
      checkpoint_dir='',
      summary_dir='',
      records_root_dir=FLAGS.records_root_dir,
      eval_finegrainedness=False,
      eval_finegrainedness_split='',
      eval_imbalance_dataset='',
      omit_from_saving_and_reloading='',
      train_episode_config=episode_config,
      eval_episode_config=episode_config,
      learn_config=learn_config,
      learner_config=learner_config,
      data_config=data_config,
  )

  # Get the next train / valid / test episodes.
  train_episode = episodic_trainer.get_next(trainer.TRAIN_SPLIT)
  self.assertIsInstance(train_episode, providers.EpisodeDataset)

  # This isn't really a test. It just checks that things don't crash...
  print(
      episodic_trainer.sess.run([
          episodic_trainer.train_op,
          episodic_trainer.losses[trainer.TRAIN_SPLIT],
          episodic_trainer.accs[trainer.TRAIN_SPLIT],
      ]))
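# Standard entry point for running this test module standalone (the module
# path in the comment below is an assumption):
#
#   python -m meta_dataset.trainer_test
#
if __name__ == '__main__':
  tf.test.main()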