Example #1
  def test_episodic_overfit(self,
                            learner_class,
                            learner_config,
                            threshold=1.,
                            attempts=1):
    """Test that error goes down when training on a single episode.

    This can help check that the trained model and the evaluated one share
    the trainable parameters correctly.

    Args:
      learner_class: A subclass of Learner.
      learner_config: A string, the Learner-specific gin configuration string.
      threshold: A float (default 1.), the validation accuracy to reach at
        least once.
      attempts: An int (default 1), how many of the last training steps to
        check for a validation accuracy that reaches the threshold.
    """
    gin_config = '\n'.join((self.BASE_GIN_CONFIG, learner_config))
    gin.parse_config(gin_config)

    episode_config = config.EpisodeDescriptionConfig(
        num_ways=self.NUM_EXAMPLES, num_support=1, num_query=1)

    trainer_instance = trainer.Trainer(
        train_learner_class=learner_class,
        eval_learner_class=learner_class,
        is_training=True,
        train_dataset_list=['dummy'],
        eval_dataset_list=['dummy'],
        records_root_dir=self.temp_dir,
        checkpoint_dir=os.path.join(self.temp_dir, 'checkpoints'),
        train_episode_config=episode_config,
        eval_episode_config=episode_config,
        data_config=config.DataConfig(),
    )
    # Run all but the last `attempts - 1` updates, then train one update at a
    # time so the validation accuracy can be checked after each of them.
    trainer_instance.num_updates -= (attempts - 1)
    trainer_instance.train()
    valid_accs = [trainer_instance.valid_acc]
    for _ in range(attempts - 1):
      trainer_instance.num_updates += 1
      trainer_instance.train()
      valid_accs.append(trainer_instance.valid_acc)
    self.assertGreaterEqual(max(valid_accs), threshold)
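
A minimal sketch of how a concrete test case might drive this helper. The base-class name `EpisodicOverfitTestBase` and the gin binding below are assumptions for illustration, not the repository's actual test classes; `learner_lib` is assumed to be imported as in Example #2:

# Hypothetical subclass of the test base above (names are assumptions).
class PrototypicalNetworkOverfitTest(EpisodicOverfitTestBase):

  def test_prototypical_network_overfits(self):
    # Allow a few extra single-update attempts to reach perfect accuracy.
    self.test_episodic_overfit(
        learner_class=learner_lib.PrototypicalNetworkLearner,
        learner_config='PrototypicalNetworkLearner.weight_decay = 0.',
        threshold=1.,
        attempts=3)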
Example #2
    def test_trainer(self):
        # PrototypicalNetworkLearner is built automatically, so this test can
        # only pass values to its constructor through gin.
        gin.bind_parameter('PrototypicalNetworkLearner.weight_decay', 1e-4)
        gin.bind_parameter(
            'PrototypicalNetworkLearner.backprop_through_moments', True)
        gin.bind_parameter(
            'PrototypicalNetworkLearner.transductive_batch_norm', False)
        gin.bind_parameter('PrototypicalNetworkLearner.embedding_fn',
                           'four_layer_convnet')

        # Values that can't be passed directly to EpisodeDescriptionConfig
        gin.bind_parameter('process_episode.support_decoder',
                           decoder.ImageDecoder())
        gin.bind_parameter('process_episode.query_decoder',
                           decoder.ImageDecoder())

        episode_config = config.EpisodeDescriptionConfig(
            num_ways=None,
            num_support=None,
            num_query=None,
            min_ways=5,
            max_ways_upper_bound=50,
            max_num_query=10,
            max_support_set_size=500,
            max_support_size_contrib_per_class=100,
            min_log_weight=np.log(0.5),
            max_log_weight=np.log(2),
            ignore_dag_ontology=False,
            ignore_bilevel_ontology=False)

        # Inspired by `learn/gin/default/debug_proto_mini_imagenet.gin`, but
        # building the objects explicitly.
        data_config = config.DataConfig(
            image_height=84,
            shuffle_buffer_size=20,
            read_buffer_size_bytes=(1024**2),
            num_prefetch=2,
        )

        trainer_instance = trainer.Trainer(
            train_learner_class=learner_lib.PrototypicalNetworkLearner,
            eval_learner_class=learner_lib.PrototypicalNetworkLearner,
            is_training=True,
            train_dataset_list=['mini_imagenet'],
            eval_dataset_list=['mini_imagenet'],
            restrict_classes={},
            restrict_num_per_class={},
            checkpoint_dir='',
            summary_dir='',
            records_root_dir=FLAGS.records_root_dir,
            eval_split=trainer.VALID_SPLIT,
            eval_finegrainedness=False,
            eval_finegrainedness_split='',
            eval_imbalance_dataset='',
            omit_from_saving_and_reloading='',
            train_episode_config=episode_config,
            eval_episode_config=episode_config,
            data_config=data_config,
            num_updates=100,
            batch_size=8,  # unused
            num_eval_episodes=10,
            checkpoint_every=10,
            validate_every=5,
            log_every=1,
            checkpoint_to_restore=None,
            learning_rate=1e-4,
            decay_learning_rate=True,
            decay_every=5000,
            decay_rate=0.5,
            experiment_name='test',
            pretrained_source='',
        )

        # Get the next training episode and check its type.
        train_episode = trainer_instance.next_data[trainer.TRAIN_SPLIT]
        self.assertIsInstance(train_episode, providers.EpisodeDataset)

        # This isn't really a test. It just checks that things don't crash...
        print(
            trainer_instance.sess.run([
                trainer_instance.train_op,
                trainer_instance.losses[trainer.TRAIN_SPLIT],
                trainer_instance.accuracies[trainer.TRAIN_SPLIT]
            ]))
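
The repeated gin.bind_parameter calls at the top of this test could also be written as a single configuration string. A sketch covering the scalar learner bindings (the decoder bindings pass object instances, so they are easier to keep as bind_parameter calls):

gin.parse_config("""
    PrototypicalNetworkLearner.weight_decay = 1e-4
    PrototypicalNetworkLearner.backprop_through_moments = True
    PrototypicalNetworkLearner.transductive_batch_norm = False
    PrototypicalNetworkLearner.embedding_fn = 'four_layer_convnet'
""")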
Example #3
def main(unused_argv):

    # Parse Gin configurations passed to this script.
    parse_cmdline_gin_configurations()

    if FLAGS.reload_checkpoint_gin_config:
        # Try to reload a previously recorded Gin configuration from an operative
        # Gin configuration file in one of the provided directories.
        # TODO(eringrant): Allow querying of a value to be bound without binding it
        # to avoid the redundant call to `parse_cmdline_gin_configurations` below.
        try:
            checkpoint_to_restore = gin.query_parameter(
                'Trainer.checkpoint_to_restore')
        except ValueError:
            checkpoint_to_restore = None

        # Load the operative Gin configurations from the checkpoint directory.
        if checkpoint_to_restore:
            restore_checkpoint_dir = os.path.dirname(checkpoint_to_restore)
            load_operative_gin_configurations(restore_checkpoint_dir)

            # Reload the command-line Gin configuration to allow overriding of the Gin
            # configuration loaded from the checkpoint directory.
            parse_cmdline_gin_configurations()

    # Wrap object instantiations to print the full Gin configuration on
    # failure.
    try:
        (train_datasets, eval_datasets, restrict_classes,
         restrict_num_per_class) = trainer.get_datasets_and_restrictions()

        # Get a trainer or evaluator.
        trainer_instance = trainer.Trainer(
            is_training=FLAGS.is_training,
            train_dataset_list=train_datasets,
            eval_dataset_list=eval_datasets,
            restrict_classes=restrict_classes,
            restrict_num_per_class=restrict_num_per_class,
            checkpoint_dir=FLAGS.train_checkpoint_dir,
            summary_dir=FLAGS.summary_dir,
            records_root_dir=FLAGS.records_root_dir,
            eval_finegrainedness=FLAGS.eval_finegrainedness,
            eval_finegrainedness_split=FLAGS.eval_finegrainedness_split,
            eval_imbalance_dataset=FLAGS.eval_imbalance_dataset,
            omit_from_saving_and_reloading=FLAGS.omit_from_saving_and_reloading,
            eval_split=FLAGS.eval_split,
        )
    except ValueError:
        logging.info('Full Gin configurations:\n%s', gin.config_str())
        raise

    # All configurable objects/functions should have been instantiated/called.
    # TODO(evcu): Tie saving of Gin configuration at training and evaluation time.
    logging.info('Operative Gin configurations:\n%s',
                 gin.operative_config_str())
    if FLAGS.is_training and FLAGS.train_checkpoint_dir:
        record_operative_gin_configurations(FLAGS.train_checkpoint_dir)
    elif not FLAGS.is_training and FLAGS.summary_dir:
        record_operative_gin_configurations(FLAGS.summary_dir)

    datasets = train_datasets if FLAGS.is_training else eval_datasets
    logging.info('Starting %s for dataset(s) %s...',
                 'training' if FLAGS.is_training else 'evaluation', datasets)
    if FLAGS.is_training:
        trainer_instance.train()
    elif set(datasets).intersection(trainer.DATASETS_WITH_EXAMPLE_SPLITS):
        if not data.POOL_SUPPORTED:
            raise NotImplementedError(
                'Example-level splits or pools not supported.')
    else:
        if len(datasets) != 1:
            raise ValueError(
                'Requested datasets {} for evaluation, but evaluation '
                'should be performed on individual datasets '
                'only.'.format(datasets))

        if FLAGS.eval_finegrainedness:
            eval_split = FLAGS.eval_finegrainedness_split
        elif FLAGS.eval_split:
            eval_split = FLAGS.eval_split
        else:
            eval_split = trainer.TEST_SPLIT

        _, _, acc_summary, ci_acc_summary = trainer_instance.evaluate(
            eval_split)
        if trainer_instance.summary_writer:
            trainer_instance.summary_writer.add_summary(acc_summary)
            trainer_instance.summary_writer.add_summary(ci_acc_summary)

    # Flushes the event file to disk and closes the file.
    if trainer_instance.summary_writer:
        trainer_instance.summary_writer.close()
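
For context, a script like this is normally launched through absl's app runner. A minimal sketch of the surrounding entry-point boilerplate (the flag definitions and the gin helper functions referenced above are assumed to be defined elsewhere in the same script):

from absl import app
from absl import logging

if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)
    app.run(main)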