Example No. 1
def main(unused_argv):
    tf.logging.info('FLAGS.gin_config: %s', FLAGS.gin_config)
    tf.logging.info('FLAGS.gin_bindings: %s', FLAGS.gin_bindings)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    learner_config = trainer.LearnerConfig()

    # Check for inconsistent or contradictory flag settings.
    if (learner_config.checkpoint_for_eval
            and learner_config.pretrained_checkpoint):
        raise ValueError(
            'Can not define both checkpoint_for_eval and '
            'pretrained_checkpoint. The difference between them is '
            'that in the former all variables are restored (including '
            'global step) whereas the latter is only applicable to '
            'the start of training for initializing the model from '
            'pre-trained weights. It is also only applicable to '
            'episodic models and restores only the embedding weights.')

    (train_datasets, eval_datasets,
     restrict_num_per_class) = trainer.get_datasets_and_restrictions()

    train_learner = None
    if FLAGS.is_training or (FLAGS.eval_finegrainedness
                             and FLAGS.eval_finegrainedness_split == 'train'):
        # If eval_finegrainedness is True, even in pure evaluation mode we still
        # require a train learner, since we may perform this analysis on the
        # training sub-graph of ImageNet too.
        train_learner = NAME_TO_LEARNER[learner_config.train_learner]
    eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

    # Get a trainer or evaluator.
    trainer_kwargs = {
        'train_learner': train_learner,
        'eval_learner': eval_learner,
        'is_training': FLAGS.is_training,
        'train_dataset_list': train_datasets,
        'eval_dataset_list': eval_datasets,
        'restrict_num_per_class': restrict_num_per_class,
        'checkpoint_dir': FLAGS.train_checkpoint_dir,
        'summary_dir': FLAGS.summary_dir,
        'records_root_dir': FLAGS.records_root_dir,
        'eval_finegrainedness': FLAGS.eval_finegrainedness,
        'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
        'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
        'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
    }
    if learner_config.episodic:
        trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
        if learner_config.train_learner not in EPISODIC_LEARNERS:
            raise ValueError(
                'When "episodic" is True, "train_learner" should be an episodic one, '
                'among {}.'.format(EPISODIC_LEARNERS))
    else:
        trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
        if learner_config.train_learner not in BATCH_LEARNERS:
            raise ValueError(
                'When "episodic" is False, "train_learner" should be a batch one, '
                'among {}.'.format(BATCH_LEARNERS))

    mode = 'training' if FLAGS.is_training else 'evaluation'
    datasets = train_datasets if FLAGS.is_training else eval_datasets
    tf.logging.info('Starting %s for dataset(s) %s...', mode, datasets)

    # Record gin operative config string after the setup, both in the logs and in
    # the checkpoint directory.
    gin_operative_config = gin.operative_config_str()
    tf.logging.info('gin configuration:\n%s', gin_operative_config)
    if FLAGS.train_checkpoint_dir:
        gin_log_file = os.path.join(FLAGS.train_checkpoint_dir,
                                    'operative_config.gin')
        # If it exists already, rename it instead of overwriting it.
        # This just saves the previous one, not all the ones before.
        if tf.gfile.Exists(gin_log_file):
            tf.gfile.Rename(gin_log_file,
                            gin_log_file + '.old',
                            overwrite=True)
        with tf.io.gfile.GFile(gin_log_file, 'w') as f:
            f.write(gin_operative_config)

    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = 'test'
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        trainer_instance.evaluate(eval_split)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
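
For context, here is a minimal entry-point sketch (not part of the original example) showing how a main like this is typically wired up, assuming absl-style flags and the TF 1.x API the example already uses; only three of the flags that main() reads are shown, and the remaining FLAGS would be defined the same way.

from absl import flags
import tensorflow.compat.v1 as tf

FLAGS = flags.FLAGS
flags.DEFINE_multi_string('gin_config', None,
                          'Paths to the Gin config files to parse.')
flags.DEFINE_multi_string('gin_bindings', None,
                          'Individual Gin parameter bindings.')
flags.DEFINE_bool('is_training', True,
                  'Whether to run training (True) or evaluation (False).')

if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main)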
Example No. 2
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    if FLAGS.reload_checkpoint_gin_config:
        # Try to reload a previously recorded Gin configuration from an operative
        # Gin configuration file in one of the provided directories.
        # TODO(eringrant): Allow querying of a value to be bound without binding it
        # to avoid the redundant call to `parse_cmdline_gin_configurations` below.
        try:
            checkpoint_to_restore = gin.query_parameter(
                'Trainer.checkpoint_to_restore')
        except ValueError:
            checkpoint_to_restore = None

        # Load the operative Gin configurations from the checkpoint directory.
        if checkpoint_to_restore:
            restore_checkpoint_dir = os.path.dirname(checkpoint_to_restore)
            load_operative_gin_configurations(restore_checkpoint_dir)

            # Reload the command-line Gin configuration to allow overriding of the Gin
            # configuration loaded from the checkpoint directory.
            parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on failure.
    try:
        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        # Get a trainer or evaluator.
        trainer_instance = trainer.Trainer(
            is_training=FLAGS.is_training,
            train_dataset_list=train_datasets,
            eval_dataset_list=eval_datasets,
            restrict_classes=restrict_classes,
            restrict_num_per_class=restrict_num_per_class,
            checkpoint_dir=FLAGS.train_checkpoint_dir,
            summary_dir=FLAGS.summary_dir,
            records_root_dir=FLAGS.records_root_dir,
            eval_finegrainedness=FLAGS.eval_finegrainedness,
            eval_finegrainedness_split=FLAGS.eval_finegrainedness_split,
            eval_imbalance_dataset=FLAGS.eval_imbalance_dataset,
            omit_from_saving_and_reloading=FLAGS.omit_from_saving_and_reloading,
            eval_split=FLAGS.eval_split,
        )
    except ValueError as e:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise e

    # All configurable objects/functions should have been instantiated/called.
    # TODO(evcu): Tie saving of Gin configuration at training and evaluation time.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
    elif not FLAGS.is_training and FLAGS.summary_dir:
        record_operative_gin_configurations(FLAGS.summary_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split
        elif FLAGS.eval_split:
            eval_split = FLAGS.eval_split
        else:
            eval_split = trainer.TEST_SPLIT

        _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(
            eval_split)
        if trainer_instance.summary_writer:
            trainer_instance.summary_writer.add_summary(acc_summary)
            trainer_instance.summary_writer.add_summary(ci_acc_summary)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
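
The reload logic above relies on two gin-config behaviors: gin.query_parameter raises ValueError for a parameter that was never bound, and bindings parsed later override earlier ones, which is why parse_cmdline_gin_configurations is called again after loading the checkpoint's operative config. A small standalone sketch of both behaviors (the configurable and its parameters are made up purely for illustration):

import gin


@gin.configurable
class Trainer(object):
    # Hypothetical configurable, used only to illustrate the query/override pattern.
    def __init__(self, checkpoint_to_restore=None, learning_rate=0.1):
        self.checkpoint_to_restore = checkpoint_to_restore
        self.learning_rate = learning_rate


# Bind one parameter, as a command-line --gin_bindings flag would.
gin.parse_config('Trainer.checkpoint_to_restore = "/tmp/ckpt/model.ckpt-1000"')

try:
    ckpt = gin.query_parameter('Trainer.checkpoint_to_restore')  # bound above
except ValueError:
    ckpt = None  # query_parameter raises ValueError for unbound parameters

# Later bindings win, so re-parsing the command-line configuration after the
# checkpoint's operative config gives the command line precedence.
gin.parse_config('Trainer.learning_rate = 0.01')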
Example No. 3
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    # Try to reload a previously recorded Gin configuration.
    # TODO(eringrant): Allow querying of a value to be bound without actually
    # binding it to avoid the redundant call to `parse_cmdline_gin_configurations`
    # below.
    checkpoint_for_eval = gin.query_parameter(
        'LearnerConfig.checkpoint_for_eval')
    if checkpoint_for_eval and FLAGS.reload_eval_checkpoint_gin_config:
        eval_checkpoint_dir = os.path.dirname(checkpoint_for_eval)
        load_operative_gin_configurations(eval_checkpoint_dir)

        # Reload the command-line Gin configuration to allow overriding of the Gin
        # configuration loaded from the checkpoint directory.
        parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on failure.
    try:
        learner_config = trainer.LearnerConfig()

        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        train_learner = None
        if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                                 FLAGS.eval_finegrainedness_split == 'train'):
            # If eval_finegrainedness is True, even in pure evaluation mode we still
            # require a train learner, since we may perform this analysis on the
            # training sub-graph of ImageNet too.
            train_learner = trainer.NAME_TO_LEARNER[
                learner_config.train_learner]
        eval_learner = trainer.NAME_TO_LEARNER[learner_config.eval_learner]

        # Get a trainer or evaluator.
        trainer_kwargs = {
            'train_learner': train_learner,
            'eval_learner': eval_learner,
            'is_training': FLAGS.is_training,
            'train_dataset_list': train_datasets,
            'eval_dataset_list': eval_datasets,
            'restrict_classes': restrict_classes,
            'restrict_num_per_class': restrict_num_per_class,
            'checkpoint_dir': FLAGS.train_checkpoint_dir,
            'summary_dir': FLAGS.summary_dir,
            'records_root_dir': FLAGS.records_root_dir,
            'eval_finegrainedness': FLAGS.eval_finegrainedness,
            'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
            'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
            'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
        }
        if learner_config.episodic:
            trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
            if learner_config.train_learner not in trainer.EPISODIC_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is True, "train_learner" should be an episodic one, '
                    'among {}.'.format(trainer.EPISODIC_LEARNER_NAMES))
        else:
            trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
            if learner_config.train_learner not in trainer.BATCH_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is False, "train_learner" should be a batch one, '
                    'among {}.'.format(trainer.BATCH_LEARNER_NAMES))

    except ValueError as e:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise e

    # All configurable objects/functions should have been instantiated/called.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = 'test'
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        trainer_instance.evaluate(eval_split)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
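
The record_operative_gin_configurations and load_operative_gin_configurations helpers are not shown in these examples; a plausible sketch, modeled on the inline file handling in Example No. 1 (the file name is an assumption):

import os

import gin
import tensorflow.compat.v1 as tf

OPERATIVE_GIN_FILENAME = 'operative_config.gin'  # assumed, as in Example No. 1


def record_operative_gin_configurations(output_dir):
    """Writes the operative Gin config to output_dir, keeping one backup."""
    gin_log_file = os.path.join(output_dir, OPERATIVE_GIN_FILENAME)
    if tf.io.gfile.exists(gin_log_file):
        # Save only the immediately previous config, not the whole history.
        tf.io.gfile.rename(gin_log_file, gin_log_file + '.old', overwrite=True)
    with tf.io.gfile.GFile(gin_log_file, 'w') as f:
        f.write(gin.operative_config_str())


def load_operative_gin_configurations(restore_dir):
    """Parses a previously recorded operative Gin config from restore_dir."""
    gin.parse_config_file(os.path.join(restore_dir, OPERATIVE_GIN_FILENAME))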
Example No. 4
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    if FLAGS.reload_checkpoint_gin_config:
        # Try to reload a previously recorded Gin configuration.
        # TODO(eringrant): Allow querying of a value to be bound without binding it
        # to avoid the redundant call to `parse_cmdline_gin_configurations` below.
        try:
            checkpoint_to_reload = gin.query_parameter(
                'Trainer.checkpoint_for_eval')
        except ValueError:
            try:
                checkpoint_to_reload = gin.query_parameter(
                    'Trainer.pretrained_checkpoint')
            except ValueError:
                checkpoint_to_reload = None

        # Load the operative Gin configurations from the checkpoint directory.
        if checkpoint_to_reload:
            reload_checkpoint_dir = os.path.dirname(checkpoint_to_reload)
            load_operative_gin_configurations(reload_checkpoint_dir)

            # Reload the command-line Gin configuration to allow overriding of the Gin
            # configuration loaded from the checkpoint directory.
            parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on failure.
    try:
        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        # Get a trainer or evaluator.
        trainer_kwargs = {
            'is_training': FLAGS.is_training,
            'train_dataset_list': train_datasets,
            'eval_dataset_list': eval_datasets,
            'restrict_classes': restrict_classes,
            'restrict_num_per_class': restrict_num_per_class,
            'checkpoint_dir': FLAGS.train_checkpoint_dir,
            'summary_dir': FLAGS.summary_dir,
            'records_root_dir': FLAGS.records_root_dir,
            'eval_finegrainedness': FLAGS.eval_finegrainedness,
            'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
            'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
            'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
        }

        train_learner_class = gin.query_parameter(
            'Trainer.train_learner_class')
        if gin.query_parameter('Trainer.episodic'):
            trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
            if train_learner_class not in trainer.EPISODIC_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is True, "train_learner" should be an episodic '
                    'one, among {}, but received {}.'.format(
                        trainer.EPISODIC_LEARNER_NAMES, train_learner_class))
        else:
            trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
            if train_learner_class not in trainer.BATCH_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is False, "train_learner_class" should be a batch one, '
                    'among {}, but received {}.'.format(
                        trainer.BATCH_LEARNER_NAMES, train_learner_class))

    except ValueError as e:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise e

    # All configurable objects/functions should have been instantiated/called.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
    # TODO(all): Improve saving of gin configs (during train and eval).
    # The branch above handles training only; the branch below is a temporary
    # workaround for evaluation.
    elif not FLAGS.is_training and FLAGS.summary_dir:
        record_operative_gin_configurations(FLAGS.summary_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = trainer.TEST_SPLIT
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(
            eval_split)
        if trainer_instance.summary_writer:
            trainer_instance.summary_writer.add_summary(acc_summary)
            trainer_instance.summary_writer.add_summary(ci_acc_summary)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
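
The add_summary/close calls in Examples No. 2 and No. 4 assume a TF 1.x tf.summary.FileWriter-style object on the trainer; a minimal sketch of that round trip (the tag names, values, and output directory are illustrative stand-ins):

import tensorflow.compat.v1 as tf

# Stand-ins for the accuracy summaries that trainer_instance.evaluate() returns.
acc_summary = tf.Summary(
    value=[tf.Summary.Value(tag='eval/acc', simple_value=0.87)])
ci_acc_summary = tf.Summary(
    value=[tf.Summary.Value(tag='eval/ci_acc', simple_value=0.01)])

summary_writer = tf.summary.FileWriter('/tmp/summaries')  # illustrative path
summary_writer.add_summary(acc_summary)
summary_writer.add_summary(ci_acc_summary)
summary_writer.close()  # flushes the event file to disk and closes it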