Example #1
0
    def test_revnet_creation(self, model_id):
        """Checks that a hand-built RevNet matches the factory-built one.

        One backbone is instantiated directly from `backbones.RevNet`. A
        second is produced by `factory.build_backbone` from equivalent
        backbone and norm/activation configs. The test asserts that both
        report the same `get_config()` dictionary.

        Args:
          model_id: RevNet depth identifier passed to both construction paths.
        """
        momentum, epsilon = 0.99, 1e-5

        # Path 1: construct the backbone class directly.
        direct = backbones.RevNet(
            model_id=model_id,
            norm_momentum=momentum,
            norm_epsilon=epsilon)

        # Path 2: construct it through the config-driven factory using the
        # same normalization settings (sync BN disabled).
        built = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            backbone_config=backbones_cfg.Backbone(
                type='revnet',
                revnet=backbones_cfg.RevNet(model_id=model_id)),
            norm_activation_config=common_cfg.NormActivation(
                norm_momentum=momentum,
                norm_epsilon=epsilon,
                use_sync_bn=False))

        expected = direct.get_config()
        actual = built.get_config()
        self.assertEqual(expected, actual)
def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
    """Builds the RevNet-56 image-classification experiment for ImageNet.

    Training setup:
    - 90 epochs at a global batch size of 4096.
    - SGD with momentum 0.9 and L2 weight decay 1e-4.
    - A 5-epoch linear warmup starting from 0.
    - A stepwise learning rate of 0.8 that drops 10x at epochs 30, 60
      and 80.

    Returns:
      A `cfg.ExperimentConfig` covering model, data, and trainer settings.
    """
    train_batch_size = eval_batch_size = 4096
    epoch_steps = IMAGENET_TRAIN_EXAMPLES // train_batch_size

    model = ImageClassificationModel(
        num_classes=1001,
        input_size=[224, 224, 3],
        backbone=backbones.Backbone(
            type='revnet', revnet=backbones.RevNet(model_id=56)),
        norm_activation=common.NormActivation(
            norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False),
        add_head_batch_norm=True)

    task = ImageClassificationTask(
        model=model,
        losses=Losses(l2_weight_decay=1e-4),
        train_data=DataConfig(
            input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
            is_training=True,
            global_batch_size=train_batch_size),
        validation_data=DataConfig(
            input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
            is_training=False,
            global_batch_size=eval_batch_size))

    # Step boundaries are expressed in epochs and converted to steps.
    optimization_settings = {
        'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': [e * epoch_steps for e in (30, 60, 80)],
                'values': [0.8, 0.08, 0.008, 0.0008],
            },
        },
        'warmup': {
            'type': 'linear',
            'linear': {
                'warmup_steps': 5 * epoch_steps,
                'warmup_learning_rate': 0,
            },
        },
    }

    # Looping, summaries, checkpoints and validation all happen once per epoch.
    trainer = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=90 * epoch_steps,
        validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
        validation_interval=epoch_steps,
        optimizer_config=optimization.OptimizationConfig(
            optimization_settings))

    # These are restriction strings interpreted by the config framework, not
    # Python expressions, so `!= None` is intentional here.
    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])