Example #1
def main(unused_argv):
    tf.logging.info('FLAGS.gin_config: %s', FLAGS.gin_config)
    tf.logging.info('FLAGS.gin_bindings: %s', FLAGS.gin_bindings)
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    learner_config = trainer.LearnerConfig()

    # Check for inconsistent or contradictory flag settings.
    if (learner_config.checkpoint_for_eval
            and learner_config.pretrained_checkpoint):
        raise ValueError(
            'Cannot define both checkpoint_for_eval and '
            'pretrained_checkpoint. The difference between them is '
            'that in the former all variables are restored (including '
            'global step) whereas the latter is only applicable to '
            'the start of training for initializing the model from '
            'pre-trained weights. It is also only applicable to '
            'episodic models and restores only the embedding weights.')

    (train_datasets, eval_datasets,
     restrict_num_per_class) = trainer.get_datasets_and_restrictions()

    train_learner = None
    if FLAGS.is_training or (FLAGS.eval_finegrainedness
                             and FLAGS.eval_finegrainedness_split == 'train'):
        # If eval_finegrainedness is True, even in pure evaluation mode we still
        # require a train learner, since we may perform this analysis on the
        # training sub-graph of ImageNet too.
        train_learner = NAME_TO_LEARNER[learner_config.train_learner]
    eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

    # Get a trainer or evaluator.
    trainer_kwargs = {
        'train_learner': train_learner,
        'eval_learner': eval_learner,
        'is_training': FLAGS.is_training,
        'train_dataset_list': train_datasets,
        'eval_dataset_list': eval_datasets,
        'restrict_num_per_class': restrict_num_per_class,
        'checkpoint_dir': FLAGS.train_checkpoint_dir,
        'summary_dir': FLAGS.summary_dir,
        'records_root_dir': FLAGS.records_root_dir,
        'eval_finegrainedness': FLAGS.eval_finegrainedness,
        'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
        'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
        'omit_from_saving_and_reloading': FLAGS.omit_from_saving_and_reloading,
    }
    if learner_config.episodic:
        trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
        if learner_config.train_learner not in EPISODIC_LEARNERS:
            raise ValueError(
                'When "episodic" is True, "train_learner" should be an episodic one, '
                'among {}.'.format(EPISODIC_LEARNERS))
    else:
        trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
        if learner_config.train_learner not in BATCH_LEARNERS:
            raise ValueError(
                'When "episodic" is False, "train_learner" should be a batch one, '
                'among {}.'.format(BATCH_LEARNERS))

    mode = 'training' if FLAGS.is_training else 'evaluation'
    datasets = train_datasets if FLAGS.is_training else eval_datasets
    tf.logging.info('Starting %s for dataset(s) %s...', mode, datasets)

    # Record gin operative config string after the setup, both in the logs and in
    # the checkpoint directory.
    gin_operative_config = gin.operative_config_str()
    tf.logging.info('gin configuration:\n%s', gin_operative_config)
    if FLAGS.train_checkpoint_dir:
        gin_log_file = os.path.join(FLAGS.train_checkpoint_dir,
                                    'operative_config.gin')
        # If it exists already, rename it instead of overwriting it.
        # This just saves the previous one, not all the ones before.
        if tf.io.gfile.exists(gin_log_file):
            tf.io.gfile.rename(gin_log_file,
                               gin_log_file + '.old',
                               overwrite=True)
        with tf.io.gfile.GFile(gin_log_file, 'w') as f:
            f.write(gin_operative_config)

    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = 'test'
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        trainer_instance.evaluate(eval_split)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
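A note on the checkpoint-directory bookkeeping above: renaming the existing operative_config.gin to '.old' before writing keeps exactly one backup of the previous run's configuration. That logic is self-contained enough to factor into a helper. A minimal sketch, assuming only the tf.io.gfile and gin APIs already used in this example; save_operative_config is a hypothetical name, not part of meta_dataset:

import os

import gin
import tensorflow as tf  # TF 1.x, matching the example above.


def save_operative_config(checkpoint_dir, filename='operative_config.gin'):
    """Writes gin's operative config, keeping one backup of the previous file."""
    gin_log_file = os.path.join(checkpoint_dir, filename)
    # Rename instead of overwriting; only the most recent previous
    # configuration survives, as in the example above.
    if tf.io.gfile.exists(gin_log_file):
        tf.io.gfile.rename(gin_log_file, gin_log_file + '.old', overwrite=True)
    with tf.io.gfile.GFile(gin_log_file, 'w') as f:
        f.write(gin.operative_config_str())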
Example #2
  def test_episodic_trainer(self):
    # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
    # building the objects explicitly.
    learn_config = trainer.LearnConfig(
        num_updates=100,
        batch_size=8,  # unused
        num_eval_episodes=10,
        checkpoint_every=10,
        validate_every=5,
        log_every=1,
        transductive_batch_norm=False,
    )

    learner_config = trainer.LearnerConfig(
        episodic=True,
        train_learner='PrototypicalNet',
        eval_learner='PrototypicalNet',
        pretrained_checkpoint='',
        checkpoint_for_eval='',
        embedding_network='four_layer_convnet',
        learning_rate=1e-4,
        decay_learning_rate=True,
        decay_every=5000,
        decay_rate=0.5,
        experiment_name='test',
        pretrained_source='',
    )

    # PrototypicalNetworkLearner is built automatically and this test does not
    # have the opportunity to pass values to its constructor except through gin.
    gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)

    # Values for EpisodeDescriptionSampler
    gin.bind_parameter('EpisodeDescriptionSampler.min_ways', 5)
    gin.bind_parameter('EpisodeDescriptionSampler.max_ways_upper_bound', 50)
    gin.bind_parameter('EpisodeDescriptionSampler.max_num_query', 10)
    gin.bind_parameter('EpisodeDescriptionSampler.max_support_set_size', 500)
    gin.bind_parameter(
        'EpisodeDescriptionSampler.max_support_size_contrib_per_class', 100)
    gin.bind_parameter('EpisodeDescriptionSampler.min_log_weight',
                       -0.69314718055994529)  # np.log(0.5)
    gin.bind_parameter('EpisodeDescriptionSampler.max_log_weight',
                       0.69314718055994529)  # np.log(2)

    data_config = config.DataConfig(
        image_height=84,
        shuffle_buffer_size=20,
        read_buffer_size_bytes=(1024**2),
        num_prefetch=2,
    )

    episodic_trainer = trainer.EpisodicTrainer(
        train_learner=learner.PrototypicalNetworkLearner,
        eval_learner=learner.PrototypicalNetworkLearner,
        is_training=True,
        dataset_list=['mini_imagenet'],
        checkpoint_dir='',
        summary_dir='',
        eval_finegrainedness=False,
        eval_finegrainedness_split='',
        eval_imbalance_dataset='',
        num_train_classes=None,
        num_test_classes=None,
        num_train_examples=None,
        num_test_examples=None,
        learn_config=learn_config,
        learner_config=learner_config,
        data_config=data_config,
    )

    # Get the next train / valid / test episodes.
    train_episode = episodic_trainer.get_next('train')
    self.assertIsInstance(train_episode, providers.EpisodeDataset)

    # This isn't really a test. It just checks that things don't crash...
    print(
        episodic_trainer.sess.run([
            episodic_trainer.train_op, episodic_trainer.losses['train'],
            episodic_trainer.accs['train']
        ]))
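The test above configures EpisodeDescriptionSampler entirely through gin.bind_parameter, which overrides a configurable's default arguments without touching the call site. A minimal standalone sketch of that mechanism; make_sampler is a hypothetical configurable standing in for the sampler, not part of meta_dataset:

import gin


@gin.configurable
def make_sampler(min_ways=1, max_ways_upper_bound=10):
    # Stand-in for EpisodeDescriptionSampler's constructor arguments.
    return (min_ways, max_ways_upper_bound)


# Bindings use the '<configurable>.<parameter>' naming seen in the test.
gin.bind_parameter('make_sampler.min_ways', 5)
gin.bind_parameter('make_sampler.max_ways_upper_bound', 50)
assert make_sampler() == (5, 50)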
Example #3
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    # Try to reload a previously recorded Gin configuration.
    # TODO(eringrant): Allow querying of a value to be bound without actually
    # binding it to avoid the redundant call to `parse_cmdline_gin_configurations`
    # below.
    checkpoint_for_eval = gin.query_parameter(
        'LearnerConfig.checkpoint_for_eval')
    if checkpoint_for_eval and FLAGS.reload_eval_checkpoint_gin_config:
        eval_checkpoint_dir = os.path.dirname(checkpoint_for_eval)
        load_operative_gin_configurations(eval_checkpoint_dir)

        # Reload the command-line Gin configuration to allow overriding of the Gin
        # configuration loaded from the checkpoint directory.
        parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on failure.
    try:
        learner_config = trainer.LearnerConfig()

        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        train_learner = None
        if FLAGS.is_training or (FLAGS.eval_finegrainedness and
                                 FLAGS.eval_finegrainedness_split == 'train'):
            # If eval_finegrainedness is True, even in pure evaluation mode we still
            # require a train learner, since we may perform this analysis on the
            # training sub-graph of ImageNet too.
            train_learner = trainer.NAME_TO_LEARNER[
                learner_config.train_learner]
        eval_learner = trainer.NAME_TO_LEARNER[learner_config.eval_learner]

        # Get a trainer or evaluator.
        trainer_kwargs = {
            'train_learner': train_learner,
            'eval_learner': eval_learner,
            'is_training': FLAGS.is_training,
            'train_dataset_list': train_datasets,
            'eval_dataset_list': eval_datasets,
            'restrict_classes': restrict_classes,
            'restrict_num_per_class': restrict_num_per_class,
            'checkpoint_dir': FLAGS.train_checkpoint_dir,
            'summary_dir': FLAGS.summary_dir,
            'records_root_dir': FLAGS.records_root_dir,
            'eval_finegrainedness': FLAGS.eval_finegrainedness,
            'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
            'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
            'omit_from_saving_and_reloading':
            FLAGS.omit_from_saving_and_reloading,
        }
        if learner_config.episodic:
            trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
            if learner_config.train_learner not in trainer.EPISODIC_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is True, "train_learner" should be an episodic one, '
                    'among {}.'.format(trainer.EPISODIC_LEARNER_NAMES))
        else:
            trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
            if learner_config.train_learner not in trainer.BATCH_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is False, "train_learner" should be a batch one, '
                    'among {}.'.format(trainer.BATCH_LEARNER_NAMES))

    except ValueError:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise

    # All configurable objects/functions should have been instantiated/called.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = 'test'
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        trainer_instance.evaluate(eval_split)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
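The try/except wrapper in this example exists only to dump the full (not merely operative) Gin configuration when object construction fails, which makes misconfigured bindings much easier to diagnose. A minimal sketch of the same pattern; build_model is a hypothetical configurable, and the bad binding is deliberate so the failure path actually runs:

import logging

import gin


@gin.configurable
def build_model(num_layers=4):
    if num_layers < 1:
        raise ValueError('num_layers must be positive, got {}.'.format(num_layers))
    return ['layer_{}'.format(i) for i in range(num_layers)]


gin.bind_parameter('build_model.num_layers', 0)  # Deliberately invalid.

try:
    model = build_model()
except ValueError:
    # gin.config_str() includes every binding, even ones not yet consumed,
    # whereas gin.operative_config_str() covers only parameters actually used.
    logging.info('Full Gin configurations:\n%s', gin.config_str())
    raise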
Example #4
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    if FLAGS.reload_checkpoint_gin_config:
        # Try to reload a previously recorded Gin configuration.
        # TODO(eringrant): Allow querying of a value to be bound without binding it
        # to avoid the redundant call to `parse_cmdline_gin_configurations` below.
        try:
            checkpoint_to_reload = gin.query_parameter(
                'Trainer.checkpoint_for_eval')
        except ValueError:
            try:
                checkpoint_to_reload = gin.query_parameter(
                    'Trainer.pretrained_checkpoint')
            except ValueError:
                checkpoint_to_reload = None

        # Load the operative Gin configurations from the checkpoint directory.
        if checkpoint_to_reload:
            reload_checkpoint_dir = os.path.dirname(checkpoint_to_reload)
            load_operative_gin_configurations(reload_checkpoint_dir)

            # Reload the command-line Gin configuration to allow overriding of the Gin
            # configuration loaded from the checkpoint directory.
            parse_cmdline_gin_configurations()

    # Wrap object instantiations to print out full Gin configuration on failure.
    try:
        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        # Get a trainer or evaluator.
        trainer_kwargs = {
            'is_training': FLAGS.is_training,
            'train_dataset_list': train_datasets,
            'eval_dataset_list': eval_datasets,
            'restrict_classes': restrict_classes,
            'restrict_num_per_class': restrict_num_per_class,
            'checkpoint_dir': FLAGS.train_checkpoint_dir,
            'summary_dir': FLAGS.summary_dir,
            'records_root_dir': FLAGS.records_root_dir,
            'eval_finegrainedness': FLAGS.eval_finegrainedness,
            'eval_finegrainedness_split': FLAGS.eval_finegrainedness_split,
            'eval_imbalance_dataset': FLAGS.eval_imbalance_dataset,
            'omit_from_saving_and_reloading':
            FLAGS.omit_from_saving_and_reloading,
        }

        train_learner_class = gin.query_parameter(
            'Trainer.train_learner_class')
        if gin.query_parameter('Trainer.episodic'):
            trainer_instance = trainer.EpisodicTrainer(**trainer_kwargs)
            if train_learner_class not in trainer.EPISODIC_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is True, "train_learner" should be an episodic '
                    'one, among {}, but received {}.'.format(
                        trainer.EPISODIC_LEARNER_NAMES, train_learner_class))
        else:
            trainer_instance = trainer.BatchTrainer(**trainer_kwargs)
            if train_learner_class not in trainer.BATCH_LEARNER_NAMES:
                raise ValueError(
                    'When "episodic" is False, "train_learner" should be a batch one, '
                    'among {}, but received {}.'.format(
                        trainer.BATCH_LEARNER_NAMES, train_learner_class))

    except ValueError:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise

    # All configurable objects/functions should have been instantiated/called.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
    # TODO(all): Improve saving of gin configs (during train and eval).
    # The branch above handles training only; the branch below is a hack for
    # evaluation.
    elif not FLAGS.is_training and FLAGS.summary_dir:
        record_operative_gin_configurations(FLAGS.summary_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = trainer.TEST_SPLIT
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(
            eval_split)
        if trainer_instance.summary_writer:
            trainer_instance.summary_writer.add_summary(acc_summary)
            trainer_instance.summary_writer.add_summary(ci_acc_summary)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
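The nested try/except ValueError blocks above implement "return the first bound parameter among several candidates, else None". The same logic flattens into a loop; query_first_bound is a hypothetical helper, and it relies on gin.query_parameter raising ValueError for unbound parameters, exactly as the original code does:

import gin


def query_first_bound(*parameter_names):
    """Returns the first bound gin parameter among the names, or None."""
    for name in parameter_names:
        try:
            return gin.query_parameter(name)
        except ValueError:
            continue  # Not bound; try the next candidate.
    return None


checkpoint_to_reload = query_first_bound('Trainer.checkpoint_for_eval',
                                         'Trainer.pretrained_checkpoint')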
Example #5
def main(unused_argv):
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

    learner_config = trainer.LearnerConfig()

    # Check for inconsistent or contradictory flag settings.
    if (learner_config.checkpoint_for_eval
            and learner_config.pretrained_checkpoint):
        raise ValueError(
            'Cannot define both checkpoint_for_eval and '
            'pretrained_checkpoint. The difference between them is '
            'that in the former all variables are restored (including '
            'global step) whereas the latter is only applicable to '
            'the start of training for initializing the model from '
            'pre-trained weights. It is also only applicable to '
            'episodic models and restores only the embedding weights.')

    datasets = get_datasets()

    train_learner = None
    if FLAGS.is_training or (FLAGS.eval_finegrainedness
                             and FLAGS.eval_finegrainedness_split == 'train'):
        # If eval_finegrainedness is True, even in pure evaluation mode we still
        # require a train learner, since we may perform this analysis on the
        # training sub-graph of ImageNet too.
        train_learner = NAME_TO_LEARNER[learner_config.train_learner]
    eval_learner = NAME_TO_LEARNER[learner_config.eval_learner]

    # Get a trainer or evaluator.
    if learner_config.episodic:
        trainer_instance = trainer.EpisodicTrainer(
            train_learner, eval_learner, FLAGS.is_training, datasets,
            FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
            FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
            FLAGS.eval_imbalance_dataset)
        if learner_config.train_learner not in EPISODIC_LEARNERS:
            raise ValueError(
                'When "episodic" is True, "train_learner" should be an episodic one, '
                'among {}.'.format(EPISODIC_LEARNERS))
    else:
        trainer_instance = trainer.BatchTrainer(
            train_learner, eval_learner, FLAGS.is_training, datasets,
            FLAGS.train_checkpoint_dir, FLAGS.summary_dir,
            FLAGS.eval_finegrainedness, FLAGS.eval_finegrainedness_split,
            FLAGS.eval_imbalance_dataset)
        if learner_config.train_learner not in BATCH_LEARNERS:
            raise ValueError(
                'When "episodic" is False, "train_learner" should be a batch one, '
                'among {}.'.format(BATCH_LEARNERS))

    mode = 'training' if FLAGS.is_training else 'evaluation'
    tf.logging.info('Starting %s for dataset(s) %s...', mode, datasets)

    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        eval_split = 'test'
        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split

        trainer_instance.evaluate(eval_split)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
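Every example here pairs the trainer construction with a registry check on train_learner. A compact sketch of that consistency check in isolation; the helper name and registry contents are placeholders, with only 'PrototypicalNet' taken from these examples:

def validate_learner_kind(episodic, train_learner, episodic_learners,
                          batch_learners):
    """Checks that `train_learner` is consistent with the `episodic` flag."""
    allowed = episodic_learners if episodic else batch_learners
    if train_learner not in allowed:
        raise ValueError(
            'When "episodic" is {}, "train_learner" should be one of '
            '{}.'.format(episodic, allowed))


# Mirrors the checks above; passes because the learner is in the registry.
validate_learner_kind(True, 'PrototypicalNet',
                      episodic_learners=('PrototypicalNet',),
                      batch_learners=())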
Example #6
    def test_episodic_trainer(self):
        # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
        # building the objects explicitly.
        learn_config = trainer.LearnConfig(
            num_updates=100,
            batch_size=8,  # unused
            num_eval_episodes=10,
            checkpoint_every=10,
            validate_every=5,
            log_every=1,
            transductive_batch_norm=False,
        )

        learner_config = trainer.LearnerConfig(
            episodic=True,
            train_learner='PrototypicalNet',
            eval_learner='PrototypicalNet',
            pretrained_checkpoint='',
            checkpoint_for_eval='',
            embedding_network='four_layer_convnet',
            learning_rate=1e-4,
            decay_learning_rate=True,
            decay_every=5000,
            decay_rate=0.5,
            experiment_name='test',
            pretrained_source='',
        )

        # PrototypicalNetworkLearner is built automatically and this test does not
        # have the opportunity to pass values to its constructor except through gin.
        gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)

        # Values that can't be passed directly to EpisodeDescriptionConfig
        gin.bind_parameter('process_episode.support_decoder',
                           decoder.ImageDecoder())
        gin.bind_parameter('process_episode.query_decoder',
                           decoder.ImageDecoder())

        episode_config = config.EpisodeDescriptionConfig(
            num_ways=None,
            num_support=None,
            num_query=None,
            min_ways=5,
            max_ways_upper_bound=50,
            max_num_query=10,
            max_support_set_size=500,
            max_support_size_contrib_per_class=100,
            min_log_weight=np.log(0.5),
            max_log_weight=np.log(2),
            ignore_dag_ontology=False,
            ignore_bilevel_ontology=False)

        data_config = config.DataConfig(
            image_height=84,
            shuffle_buffer_size=20,
            read_buffer_size_bytes=(1024**2),
            num_prefetch=2,
        )

        episodic_trainer = trainer.EpisodicTrainer(
            train_learner=learner.PrototypicalNetworkLearner,
            eval_learner=learner.PrototypicalNetworkLearner,
            is_training=True,
            train_dataset_list=['mini_imagenet'],
            eval_dataset_list=['mini_imagenet'],
            restrict_classes={},
            restrict_num_per_class={},
            checkpoint_dir='',
            summary_dir='',
            records_root_dir=FLAGS.records_root_dir,
            eval_finegrainedness=False,
            eval_finegrainedness_split='',
            eval_imbalance_dataset='',
            omit_from_saving_and_reloading='',
            train_episode_config=episode_config,
            eval_episode_config=episode_config,
            learn_config=learn_config,
            learner_config=learner_config,
            data_config=data_config,
        )

        # Get the next train / valid / test episodes.
        train_episode = episodic_trainer.get_next(trainer.TRAIN_SPLIT)
        self.assertIsInstance(train_episode, providers.EpisodeDataset)

        # This isn't really a test. It just checks that things don't crash...
        print(
            episodic_trainer.sess.run([
                episodic_trainer.train_op,
                episodic_trainer.losses[trainer.TRAIN_SPLIT],
                episodic_trainer.accs[trainer.TRAIN_SPLIT]
            ]))
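Example #2 binds min_log_weight and max_log_weight as hard-coded floats with comments noting they equal np.log(0.5) and np.log(2); this example computes them directly, as does the short sketch below, which makes the symmetric log-space range explicit. Assumes only numpy:

import numpy as np

# log(0.5) == -log(2), so the [min_log_weight, max_log_weight] interval
# is symmetric around 0 in log space.
min_log_weight = np.log(0.5)  # -0.6931471805599453
max_log_weight = np.log(2.0)  # +0.6931471805599453
assert np.isclose(min_log_weight, -max_log_weight)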