def main(unused_argv):
  tf.logging.info('FLAGS.gin_config: %s', FLAGS.gin_config)
  tf.logging.info('FLAGS.gin_bindings: %s', FLAGS.gin_bindings)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  learner_config = trainer.LearnerConfig()

  # Check for inconsistent or contradictory flag settings.
  if (learner_config.checkpoint_for_eval and
      learner_config.pretrained_checkpoint):
    raise ValueError('Can not define both checkpoint_for_eval and '
                     'pretrained_checkpoint. The difference between them is '
                     'that in the former all variables are restored (including '
                     'global step) whereas the latter is only applicable to '
                     'the start of training for initializing the model from '
                     'pre-trained weights. It is also only applicable to '
                     'episodic models and restores only the embedding weights.')

  (train_datasets, eval_datasets,
   restrict_num_per_class) = trainer.get_datasets_and_restrictions()

  train_learner = None
  if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                           FLAGS.eval_finegrainedness_split == 'train'):
    # If eval_finegrainedness is True, even in pure evaluation mode we still
    # require a train learner, since we may perform this analysis on the
    # training sub-graph of ImageNet too.
    train_learner = NAME_TO_LEARNER[learner_config.train_learner]
  eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

  # Get a trainer or evaluator.
  trainer_kwargs = {
      'train_learner': train_learner,
      'eval_learner': eval_learner,
      'is_training': FLAGS.is_training,
      'train_dataset_list': train_datasets,
      'eval_dataset_list': eval_datasets,
      'restrict_num_per_class': restrict_num_per_class,
      'checkpoint_dir': FLAGS.train_checkpoint_dir,
      'summary_dir': FLAGS.summary_dir,
      'records_root_dir': FLAGS.records_root_dir,
      'eval_finegrainedness': FLAGS.eval_finegrainedness,
      'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
      'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
      'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
  }
  if learner_config.episodic:
    trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
    if learner_config.train_learner not in EPISODIC_LEARNERS:
      raise ValueError(
          'When "episodic" is True, "train_learner" should be an episodic '
          'one, among {}.'.format(EPISODIC_LEARNERS))
  else:
    trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
    if learner_config.train_learner not in BATCH_LEARNERS:
      raise ValueError(
          'When "episodic" is False, "train_learner" should be a batch one, '
          'among {}.'.format(BATCH_LEARNERS))

  mode = 'training' if FLAGS.is_training else 'evaluation'
  datasets = train_datasets if FLAGS.is_training else eval_datasets
  tf.logging.info('Starting %s for dataset(s) %s...' % (mode, datasets))

  # Record gin operative config string after the setup, both in the logs and
  # in the checkpoint directory.
  gin_operative_config = gin.operative_config_str()
  tf.logging.info('gin configuration:\n%s', gin_operative_config)
  if FLAGS.train_checkpoint_dir:
    gin_log_file = os.path.join(FLAGS.train_checkpoint_dir,
                                'operative_config.gin')
    # If it exists already, rename it instead of overwriting it.
    # This just saves the previous one, not all the ones before.
    if tf.gfile.Exists(gin_log_file):
      tf.gfile.Rename(gin_log_file, gin_log_file + '.old', overwrite=True)
    with tf.io.gfile.GFile(gin_log_file, 'w') as f:
      f.write(gin_operative_config)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError('Requested datasets {} for evaluation, but evaluation '
                       'should be performed on individual datasets '
                       'only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flushes the event file to disk and closes the file.
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
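

# The snippet below is not part of the original script; it is a minimal,
# standalone sketch of how the Gin plumbing used in `main` behaves:
# `gin.parse_config_files_and_bindings` applies config files first, then the
# individual binding strings (here, what `--gin_bindings` would supply)
# override them, and `gin.operative_config_str` reports only the bindings that
# were actually used. The configurable `sample_fn` and its parameter are made
# up for illustration.
import gin


@gin.configurable
def sample_fn(learning_rate=0.1):
  return learning_rate


gin.parse_config_files_and_bindings(
    config_files=[], bindings=['sample_fn.learning_rate = 0.01'])
assert sample_fn() == 0.01  # The binding overrides the Python default.
print(gin.operative_config_str())  # Only `sample_fn.learning_rate` appears.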


def main(unused_argv):
  # Parse Gin configurations passed to this script.
  parse_cmdline_gin_configurations()

  if FLAGS.reload_checkpoint_gin_config:
    # Try to reload a previously recorded Gin configuration.
    # TODO(eringrant): Allow querying of a value to be bound without binding it
    # to avoid the redundant call to `parse_cmdline_gin_configurations` below.
    try:
      checkpoint_to_reload = gin.query_parameter('Trainer.checkpoint_for_eval')
    except ValueError:
      try:
        checkpoint_to_reload = gin.query_parameter(
            'Trainer.pretrained_checkpoint')
      except ValueError:
        checkpoint_to_reload = None

    # Load the operative Gin configurations from the checkpoint directory.
    if checkpoint_to_reload:
      reload_checkpoint_dir = os.path.dirname(checkpoint_to_reload)
      load_operative_gin_configurations(reload_checkpoint_dir)

      # Reload the command-line Gin configuration to allow overriding of the
      # Gin configuration loaded from the checkpoint directory.
      parse_cmdline_gin_configurations()

  # Wrap object instantiations to print out the full Gin configuration on
  # failure.
  try:
    (train_datasets, eval_datasets, restrict_classes,
     restrict_num_per_class) = trainer.get_datasets_and_restrictions()

    # Get a trainer or evaluator.
    trainer_kwargs = {
        'is_training': FLAGS.is_training,
        'train_dataset_list': train_datasets,
        'eval_dataset_list': eval_datasets,
        'restrict_classes': restrict_classes,
        'restrict_num_per_class': restrict_num_per_class,
        'checkpoint_dir': FLAGS.train_checkpoint_dir,
        'summary_dir': FLAGS.summary_dir,
        'records_root_dir': FLAGS.records_root_dir,
        'eval_finegrainedness': FLAGS.eval_finegrainedness,
        'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
        'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
        'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
    }

    train_learner_class = gin.query_parameter('Trainer.train_learner_class')
    if gin.query_parameter('Trainer.episodic'):
      trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
      if train_learner_class not in trainer.EPISODIC_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is True, "train_learner" should be an episodic '
            'one, among {}, but received {}.'.format(
                trainer.EPISODIC_LEARNER_NAMES, train_learner_class))
    else:
      trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
      if train_learner_class not in trainer.BATCH_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is False, "train_learner" should be a batch one, '
            'among {}, but received {}.'.format(trainer.BATCH_LEARNER_NAMES,
                                                train_learner_class))
  except ValueError as e:
    logging.info('Full Gin configurations:\n%s', gin.config_str())
    raise e

  # All configurable objects/functions should have been instantiated/called.
  logging.info('Operative Gin configurations:\n%s', gin.operative_config_str())
  if FLAGS.is_training and FLAGS.train_checkpoint_dir:
    record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
  # TODO(all): Improve saving of Gin configs during both training and
  # evaluation. The branch above handles training only; the branch below is a
  # hack for evaluation.
  elif not FLAGS.is_training and FLAGS.summary_dir:
    record_operative_gin_configurations(FLAGS.summary_dir)

  datasets = train_datasets if FLAGS.is_training else eval_datasets
  logging.info('Starting %s for dataset(s) %s...',
               'training' if FLAGS.is_training else 'evaluation', datasets)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError('Requested datasets {} for evaluation, but evaluation '
                       'should be performed on individual datasets '
                       'only.'.format(datasets))

    eval_split = trainer.TEST_SPLIT
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(eval_split)
    if trainer_instance.summary_writer:
      trainer_instance.summary_writer.add_summary(acc_summary)
      trainer_instance.summary_writer.add_summary(ci_acc_summary)

  # Flushes the event file to disk and closes the file.
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
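

# Not part of the original script: a standalone sketch of the
# `gin.query_parameter` pattern that the version of `main` above relies on.
# Querying a parameter that was never bound raises ValueError, which is why
# the checkpoint-reload logic falls back through nested try/except blocks.
# The configurable `DummyTrainer` and its parameters are hypothetical.
import gin


@gin.configurable
class DummyTrainer(object):

  def __init__(self, checkpoint_for_eval=None, pretrained_checkpoint=None):
    self.checkpoint_for_eval = checkpoint_for_eval
    self.pretrained_checkpoint = pretrained_checkpoint


gin.parse_config('DummyTrainer.checkpoint_for_eval = "/tmp/model.ckpt-1000"')
assert (gin.query_parameter('DummyTrainer.checkpoint_for_eval') ==
        '/tmp/model.ckpt-1000')
try:
  gin.query_parameter('DummyTrainer.pretrained_checkpoint')  # Never bound.
except ValueError:
  pass  # Fall back, as `main` does when a checkpoint binding is absent.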


def main(unused_argv):
  # Parse Gin configurations passed to this script.
  parse_cmdline_gin_configurations()

  # Try to reload a previously recorded Gin configuration.
  # TODO(eringrant): Allow querying of a value to be bound without actually
  # binding it to avoid the redundant call to
  # `parse_cmdline_gin_configurations` below.
  checkpoint_for_eval = gin.query_parameter(
      'LearnerConfig.checkpoint_for_eval')
  if checkpoint_for_eval and FLAGS.reload_eval_checkpoint_gin_config:
    eval_checkpoint_dir = os.path.dirname(checkpoint_for_eval)
    load_operative_gin_configurations(eval_checkpoint_dir)

    # Reload the command-line Gin configuration to allow overriding of the Gin
    # configuration loaded from the checkpoint directory.
    parse_cmdline_gin_configurations()

  # Wrap object instantiations to print out full Gin configuration on failure.
  try:
    learner_config = trainer.LearnerConfig()

    (train_datasets, eval_datasets, restrict_classes,
     restrict_num_per_class) = trainer.get_datasets_and_restrictions()

    train_learner = None
    if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                             FLAGS.eval_finegrainedness_split == 'train'):
      # If eval_finegrainedness is True, even in pure evaluation mode we still
      # require a train learner, since we may perform this analysis on the
      # training sub-graph of ImageNet too.
      train_learner = trainer.NAME_TO_LEARNER[learner_config.train_learner]
    eval_learner = trainer.NAME_TO_LEARNER[learner_config.eval_learner]

    # Get a trainer or evaluator.
    trainer_kwargs = {
        'train_learner': train_learner,
        'eval_learner': eval_learner,
        'is_training': FLAGS.is_training,
        'train_dataset_list': train_datasets,
        'eval_dataset_list': eval_datasets,
        'restrict_classes': restrict_classes,
        'restrict_num_per_class': restrict_num_per_class,
        'checkpoint_dir': FLAGS.train_checkpoint_dir,
        'summary_dir': FLAGS.summary_dir,
        'records_root_dir': FLAGS.records_root_dir,
        'eval_finegrainedness': FLAGS.eval_finegrainedness,
        'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
        'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
        'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
    }
    if learner_config.episodic:
      trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
      if learner_config.train_learner not in trainer.EPISODIC_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is True, "train_learner" should be an episodic '
            'one, among {}.'.format(trainer.EPISODIC_LEARNER_NAMES))
    else:
      trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
      if learner_config.train_learner not in trainer.BATCH_LEARNER_NAMES:
        raise ValueError(
            'When "episodic" is False, "train_learner" should be a batch one, '
            'among {}.'.format(trainer.BATCH_LEARNER_NAMES))
  except ValueError as e:
    logging.info('Full Gin configurations:\n%s', gin.config_str())
    raise e

  # All configurable objects/functions should have been instantiated/called.
  logging.info('Operative Gin configurations:\n%s', gin.operative_config_str())
  if FLAGS.is_training and FLAGS.train_checkpoint_dir:
    record_operative_gin_configurations(FLAGS.train_checkpoint_dir)

  datasets = train_datasets if FLAGS.is_training else eval_datasets
  logging.info('Starting %s for dataset(s) %s...',
               'training' if FLAGS.is_training else 'evaluation', datasets)

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError('Requested datasets {} for evaluation, but evaluation '
                       'should be performed on individual datasets '
                       'only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flushes the event file to disk and closes the file.
  if trainer_instance.summary_writer:
    trainer_instance.summary_writer.close()
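

# `parse_cmdline_gin_configurations`, `load_operative_gin_configurations`, and
# `record_operative_gin_configurations` are module-level helpers that are not
# shown in this excerpt. Below is a plausible sketch of the latter two,
# assuming they mirror the inline operative-config handling from the first
# version of `main` above (an `operative_config.gin` file in the given
# directory); the actual implementations may differ.
def _record_operative_gin_configurations_sketch(output_dir):
  """Writes the operative Gin configuration to `output_dir`."""
  gin_log_file = os.path.join(output_dir, 'operative_config.gin')
  with tf.io.gfile.GFile(gin_log_file, 'w') as f:
    f.write(gin.operative_config_str())


def _load_operative_gin_configurations_sketch(restore_dir):
  """Re-binds the Gin configuration previously recorded in `restore_dir`."""
  gin_log_file = os.path.join(restore_dir, 'operative_config.gin')
  gin.parse_config_file(gin_log_file)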


def main(unused_argv):
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  learner_config = trainer.LearnerConfig()

  # Check for inconsistent or contradictory flag settings.
  if (learner_config.checkpoint_for_eval and
      learner_config.pretrained_checkpoint):
    raise ValueError('Can not define both checkpoint_for_eval and '
                     'pretrained_checkpoint. The difference between them is '
                     'that in the former all variables are restored (including '
                     'global step) whereas the latter is only applicable to '
                     'the start of training for initializing the model from '
                     'pre-trained weights. It is also only applicable to '
                     'episodic models and restores only the embedding weights.')

  datasets = get_datasets()

  train_learner = None
  if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                           FLAGS.eval_finegrainedness_split == 'train'):
    # If eval_finegrainedness is True, even in pure evaluation mode we still
    # require a train learner, since we may perform this analysis on the
    # training sub-graph of ImageNet too.
    train_learner = NAME_TO_LEARNER[learner_config.train_learner]
  eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

  # Get a trainer or evaluator.
  if learner_config.episodic:
    trainer_instance = trainer.EpisodicTrainer(
        train_learner, eval_learner, FLAGS.is_training, datasets,
        FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
        FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
        FLAGS.eval_imbalance_dataset)
    if learner_config.train_learner not in EPISODIC_LEARNERS:
      raise ValueError(
          'When "episodic" is True, "train_learner" should be an episodic '
          'one, among {}.'.format(EPISODIC_LEARNERS))
  else:
    trainer_instance = trainer.BatchTrainer(
        train_learner, eval_learner, FLAGS.is_training, datasets,
        FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
        FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
        FLAGS.eval_imbalance_dataset)
    if learner_config.train_learner not in BATCH_LEARNERS:
      raise ValueError(
          'When "episodic" is False, "train_learner" should be a batch one, '
          'among {}.'.format(BATCH_LEARNERS))

  mode = 'training' if FLAGS.is_training else 'evaluation'
  tf.logging.info('Starting %s for dataset(s) %s...' % (mode, datasets))

  if FLAGS.is_training:
    trainer_instance.train()
  elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
    if not data.POOL_SUPPORTED:
      raise NotImplementedError('Example-level splits or pools not supported.')
  else:
    if len(datasets) != 1:
      raise ValueError('Requested datasets {} for evaluation, but evaluation '
                       'should be performed on individual datasets '
                       'only.'.format(datasets))

    eval_split = 'test'
    if FLAGS.eval_finegrainedness:
      eval_split = FLAGS.eval_finegrainedness_split
    trainer_instance.evaluate(eval_split)

  # Flushes the event file to disk and closes the file.
  trainer_instance.summary_writer.close()
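

# None of the versions of `main` above include their entry-point boilerplate.
# A minimal sketch of how such a TF1-style script is typically launched,
# assuming the flags referenced in `main` (e.g. `gin_config`, `is_training`,
# `train_checkpoint_dir`) are declared elsewhere in the module with
# `tf.flags.DEFINE_*`; the original launcher may differ.
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)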