def test_episodic_overfit(self,
                          learner_class,
                          learner_config,
                          threshold=1.,
                          attempts=1):
    """Test that error goes down when training on a single episode.

    A passing run helps confirm that the trained model and the evaluated one
    share their trainable parameters correctly.

    Args:
      learner_class: A subclass of Learner.
      learner_config: A string, the Learner-specific gin configuration string.
      threshold: A float (default 1.), the performance to reach at least once.
      attempts: An int (default 1), how many of the last steps should be
        checked when looking for a validation value reaching the threshold.
    """
    # The learner-specific bindings are appended to the shared base config.
    gin.parse_config('\n'.join((self.BASE_GIN_CONFIG, learner_config)))

    one_shot_episodes = config.EpisodeDescriptionConfig(
        num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)
    checkpoint_path = os.path.join(self.temp_dir, 'checkpoints')

    trainer_instance = trainer.Trainer(
        train_learner_class=learner_class,
        eval_learner_class=learner_class,
        is_training=True,
        train_dataset_list=['dummy'],
        eval_dataset_list=['dummy'],
        records_root_dir=self.temp_dir,
        checkpoint_dir=checkpoint_path,
        train_episode_config=one_shot_episodes,
        eval_episode_config=one_shot_episodes,
        data_config=config.DataConfig(),
        # BEGIN GOOGLE_INTERNAL
        real_episodes=False,
        real_episodes_results_dir='',
        # END GOOGLE_INTERNAL
    )

    # Hold back `attempts - 1` updates from the first `train()` call, then
    # raise the budget by one update per additional call so that the
    # validation accuracy is recorded after each of the final steps.
    held_back_updates = attempts - 1
    trainer_instance.num_updates -= held_back_updates
    trainer_instance.train()
    recorded_accuracies = [trainer_instance.valid_acc]
    for _ in range(held_back_updates):
        trainer_instance.num_updates += 1
        trainer_instance.train()
        recorded_accuracies.append(trainer_instance.valid_acc)

    # Reaching the threshold once among the checked steps is enough.
    best_accuracy = max(recorded_accuracies)
    self.assertGreaterEqual(best_accuracy, threshold)
def test_trainer(self):
    """Smoke-test: build a Trainer and run one training step.

    Configures `PrototypicalNetworkLearner` through gin, constructs a
    `trainer.Trainer` on 'mini_imagenet' with explicitly built config
    objects, checks the type of the training data, and executes the train op
    together with the training loss and accuracy once.
    """
    # PrototypicalNetworkLearner is built automatically, so gin bindings are
    # the only way for this test to pass values to its constructor.
    learner_bindings = (
        ('PrototypicalNetworkLearner.weight_decay', 1e-4),
        ('PrototypicalNetworkLearner.backprop_through_moments', True),
        ('PrototypicalNetworkLearner.transductive_batch_norm', False),
        ('PrototypicalNetworkLearner.embedding_fn', 'four_layer_convnet'),
    )
    for binding_name, binding_value in learner_bindings:
        gin.bind_parameter(binding_name, binding_value)

    # The decoders can't be passed directly to EpisodeDescriptionConfig, so
    # they are bound on `process_episode`; each binding gets its own decoder.
    for decoder_binding in ('process_episode.support_decoder',
                            'process_episode.query_decoder'):
        gin.bind_parameter(decoder_binding, decoder.ImageDecoder())

    # Ways / support / query are left as None; only the bounds are supplied.
    episode_config = config.EpisodeDescriptionConfig(
        num_ways=None,
        num_support=None,
        num_query=None,
        min_ways=5,
        max_ways_upper_bound=50,
        max_num_query=10,
        max_support_set_size=500,
        max_support_size_contrib_per_class=100,
        min_log_weight=np.log(0.5),
        max_log_weight=np.log(2),
        ignore_dag_ontology=False,
        ignore_bilevel_ontology=False)

    # Modelled on `learn/gin/default/debug_proto_mini_imagenet.gin`, except
    # that the objects are built explicitly here.
    data_config = config.DataConfig(
        image_height=84,
        shuffle_buffer_size=20,
        read_buffer_size_bytes=(1024**2),
        num_prefetch=2,
    )

    proto_learner = learner_lib.PrototypicalNetworkLearner
    trainer_instance = trainer.Trainer(
        train_learner_class=proto_learner,
        eval_learner_class=proto_learner,
        is_training=True,
        train_dataset_list=['mini_imagenet'],
        eval_dataset_list=['mini_imagenet'],
        restrict_classes={},
        restrict_num_per_class={},
        checkpoint_dir='',
        summary_dir='',
        records_root_dir=FLAGS.records_root_dir,
        eval_split=trainer.VALID_SPLIT,
        eval_finegrainedness=False,
        eval_finegrainedness_split='',
        eval_imbalance_dataset='',
        omit_from_saving_and_reloading='',
        train_episode_config=episode_config,
        eval_episode_config=episode_config,
        data_config=data_config,
        num_updates=100,
        batch_size=8,  # unused
        num_eval_episodes=10,
        checkpoint_every=10,
        validate_every=5,
        log_every=1,
        checkpoint_to_restore=None,
        learning_rate=1e-4,
        decay_learning_rate=True,
        decay_every=5000,
        decay_rate=0.5,
        experiment_name='test',
        pretrained_source='',
    )

    # The training split's data should be episodic.
    next_train_data = trainer_instance.next_data[trainer.TRAIN_SPLIT]
    self.assertIsInstance(next_train_data, providers.EpisodeDataset)

    # Not a real assertion: running one step only verifies nothing crashes.
    step_fetches = [
        trainer_instance.train_op,
        trainer_instance.losses[trainer.TRAIN_SPLIT],
        trainer_instance.accuracies[trainer.TRAIN_SPLIT],
    ]
    print(trainer_instance.sess.run(step_fetches))
def main(unused_argv):
    """Parse the Gin configuration, build a Trainer, then train or evaluate.

    Steps, in order:
      1. Parse the Gin configurations given on the command line.
      2. If `FLAGS.reload_checkpoint_gin_config` is set and
         `Trainer.checkpoint_to_restore` is bound, load the operative Gin
         configuration from that checkpoint's directory and re-apply the
         command-line configuration on top of it.
      3. Build the `trainer.Trainer`.
      4. Log the operative Gin configuration and record it in the checkpoint
         directory (training) or the summary directory (evaluation).
      5. Train, or evaluate on a single dataset and write the accuracy
         summaries.
      6. Close the summary writer, whether or not step 5 succeeded.

    Args:
      unused_argv: Positional command-line arguments; not used.

    Raises:
      ValueError: Re-raised from dataset/Trainer construction after the full
        Gin configuration has been logged, or raised when evaluation is
        requested on a number of datasets other than one.
      NotImplementedError: If an evaluation dataset is in
        `trainer.DATASETS_WITH_EXAMPLE_SPLITS` but `data.POOL_SUPPORTED` is
        false.
    """
    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    if FLAGS.reload_checkpoint_gin_config:
        # Try to reload a previously recorded Gin configuration from an
        # operative Gin configuration file in one of the provided directories.
        # TODO(eringrant): Allow querying of a value to be bound without
        # binding it to avoid the redundant call to
        # `parse_cmdline_gin_configurations` below.
        try:
            checkpoint_to_restore = gin.query_parameter(
                'Trainer.checkpoint_to_restore')
        except ValueError:
            # Handled by skipping the reload entirely.
            checkpoint_to_restore = None

        # Load the operative Gin configurations from the checkpoint directory.
        if checkpoint_to_restore:
            restore_checkpoint_dir = os.path.dirname(checkpoint_to_restore)
            load_operative_gin_configurations(restore_checkpoint_dir)

            # Reload the command-line Gin configuration to allow overriding of
            # the Gin configuration loaded from the checkpoint directory.
            parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on
    # failure.
    try:
        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        # Get a trainer or evaluator.
        trainer_instance = trainer.Trainer(
            is_training=FLAGS.is_training,
            train_dataset_list=train_datasets,
            eval_dataset_list=eval_datasets,
            restrict_classes=restrict_classes,
            restrict_num_per_class=restrict_num_per_class,
            checkpoint_dir=FLAGS.train_checkpoint_dir,
            summary_dir=FLAGS.summary_dir,
            records_root_dir=FLAGS.records_root_dir,
            eval_finegrainedness=FLAGS.eval_finegrainedness,
            eval_finegrainedness_split=FLAGS.eval_finegrainedness_split,
            eval_imbalance_dataset=FLAGS.eval_imbalance_dataset,
            omit_from_saving_and_reloading=FLAGS.omit_from_saving_and_reloading,
            eval_split=FLAGS.eval_split,
        )
    except ValueError:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        # Bare `raise` re-raises the active exception with its original
        # traceback intact.
        raise

    # All configurable objects/functions should have been instantiated/called.
    # TODO(evcu): Tie saving of Gin configuration at training and evaluation
    # time.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
    elif not FLAGS.is_training and FLAGS.summary_dir:
        record_operative_gin_configurations(FLAGS.summary_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)

    # The writer is closed in `finally` so that the event file is flushed and
    # released even when training or evaluation raises.
    try:
        if FLAGS.is_training:
            trainer_instance.train()
        elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
            if not data.POOL_SUPPORTED:
                raise NotImplementedError(
                    'Example-level splits or pools not supported.')
            # NOTE(review): when pools are supported, nothing is evaluated in
            # this branch -- confirm that this is intended.
        else:
            if len(datasets) != 1:
                raise ValueError(
                    'Requested datasets {} for evaluation, but evaluation '
                    'should be performed on individual datasets '
                    'only.'.format(datasets))

            # Split precedence: fine-grainedness split, then the explicitly
            # requested split, then the test split.
            if FLAGS.eval_finegrainedness:
                eval_split = FLAGS.eval_finegrainedness_split
            elif FLAGS.eval_split:
                eval_split = FLAGS.eval_split
            else:
                eval_split = trainer.TEST_SPLIT

            _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(
                eval_split)
            if trainer_instance.summary_writer:
                trainer_instance.summary_writer.add_summary(acc_summary)
                trainer_instance.summary_writer.add_summary(ci_acc_summary)
    finally:
        # Flushes the event file to disk and closes the file.
        if trainer_instance.summary_writer:
            trainer_instance.summary_writer.close()